def train_generator(
        model_G,
        model_D,
        trainloader,
        optimizer_G,
        train_dataset_size,
        device
):
    """
    Train the generator (segmentation network) for one epoch with GT.

    Combines the cross-entropy segmentation loss with an adversarial loss
    that pushes the (frozen) discriminator to label the generator's
    predictions as ground truth.

    :param model_G: segmentation network (generator); grads are enabled.
    :param model_D: discriminator; grads are frozen during this phase.
    :param trainloader: dataloader yielding (points, cls_gt, seg_gt).
    :param optimizer_G: optimizer stepping model_G's parameters.
    :param train_dataset_size: number of training samples (tqdm total).
    :param device: torch device to run on.
    :return: (mean loss_ce, mean loss_adv) over the epoch.
    """
    loss_ce_value = []
    loss_adv_value = []
    # BUG FIX: tqdm's `total` expects an int; np.floor returns a float.
    # Integer floor division preserves the original count.
    NUM_BATCHES = train_dataset_size // BATCH_SIZE

    for i, mini_batch in tqdm.tqdm(enumerate(trainloader), total=NUM_BATCHES):
        # Train G only: enable grads in G, freeze D so its weights don't
        # accumulate gradients from the adversarial term.
        for param in model_G.parameters():
            param.requires_grad = True
        for param in model_D.parameters():
            param.requires_grad = False

        optimizer_G.zero_grad()

        points, cls_gt, seg_gt = mini_batch
        points, cls_gt, seg_gt = Variable(points).float(), \
                                 Variable(cls_gt).float(), \
                                 Variable(seg_gt).type(torch.LongTensor)
        points, cls_gt, seg_gt = points.to(device), \
                                 cls_gt.to(device), \
                                 seg_gt.to(device)
        pred = model_G(points, cls_gt)
        # Cross-entropy segmentation loss against ground truth.
        loss_ce = loss_calc(pred, seg_gt, device, mask=False)
        # Adversarial loss: fool D into labelling the prediction as GT.
        D_out = model_D(F.softmax(pred, dim=2))
        # BUG FIX: np.bool was removed in NumPy 1.24; use the builtin bool.
        ignore_mask = np.zeros(seg_gt.shape).astype(bool)
        loss_adv = loss_bce(D_out, make_D_label(GT_LABEL, ignore_mask, device), device)
        loss_seg = loss_ce + LAMBDA_ADV * loss_adv
        loss_seg.backward()
        optimizer_G.step()

        loss_ce_value.append(loss_ce.item())
        loss_adv_value.append(loss_adv.item())

    return np.average(loss_ce_value), np.average(loss_adv_value)
    def train(self, src_loader, tar_loader, val_loader):
        """
        Domain-adaptation training loop for the segmentation network.

        Per iteration:
          1. Train G (self.model) with the source segmentation loss, an
             optional target segmentation loss, an optional adversarial
             loss against self.model_D, and an optional weight-discrepancy
             loss between the two classifier heads.
          2. Train the discriminator D on source vs. target predictions.
          3. Train the self-supervised rotation classifier self.model_A.
        Periodically saves checkpoints, validates and computes mIoU.

        :param src_loader: dataloader over the labelled source domain.
        :param tar_loader: dataloader over the target domain.
        :param val_loader: dataloader used by self.validate.
        """
        loss_rot = loss_adv = loss_weight = loss_D_s = loss_D_t = 0
        args = self.args
        log = self.logger
        device = self.device

        # Upsamplers that bring network outputs back to full image size.
        interp_source = nn.Upsample(size=(args.datasets.source.images_size[1],
                                          args.datasets.source.images_size[0]),
                                    mode='bilinear',
                                    align_corners=True)
        interp_target = nn.Upsample(size=(args.datasets.target.images_size[1],
                                          args.datasets.target.images_size[0]),
                                    mode='bilinear',
                                    align_corners=True)
        interp_prediction = nn.Upsample(size=(args.auxiliary.images_size[1],
                                              args.auxiliary.images_size[0]),
                                        mode='bilinear',
                                        align_corners=True)

        source_iter = enumerate(src_loader)
        target_iter = enumerate(tar_loader)

        self.model.train()
        self.model = self.model.to(device)

        if args.method.adversarial:
            self.model_D.train()
            self.model_D = self.model_D.to(device)

        if args.method.self:
            self.model_A.train()
            self.model_A = self.model_A.to(device)

        log.info('###########   TRAINING STARTED  ############')
        start = time.time()

        for i_iter in range(self.start_iter, self.num_steps):

            self.model.train()
            self.optimizer.zero_grad()
            adjust_learning_rate(self.optimizer, self.preheat, args.num_steps,
                                 args.power, i_iter, args.model.optimizer)

            if args.method.adversarial:
                self.model_D.train()
                self.optimizer_D.zero_grad()
                adjust_learning_rate(self.optimizer_D, self.preheat,
                                     args.num_steps, args.power, i_iter,
                                     args.discriminator.optimizer)

            if args.method.self:
                self.model_A.train()
                self.optimizer_A.zero_grad()
                adjust_learning_rate(self.optimizer_A, self.preheat,
                                     args.num_steps, args.power, i_iter,
                                     args.auxiliary.optimizer)

            # Linearly decaying factor for the auxiliary objectives.
            damping = (1 - i_iter / self.num_steps
                       )  # similar to early stopping

            # ======================================================================================
            # train G
            # ======================================================================================
            if args.method.adversarial:
                for param in self.model_D.parameters():  # Remove Grads in D
                    param.requires_grad = False

            # Train with Source
            _, batch = next(source_iter)
            images_s, labels_s, _, _ = batch
            images_s = images_s.to(device)
            pred_source1_, pred_source2_ = self.model(images_s)

            pred_source1 = interp_source(pred_source1_)
            pred_source2 = interp_source(pred_source2_)

            # Segmentation Loss
            loss_seg = (
                loss_calc(self.num_classes, pred_source1, labels_s, device) +
                loss_calc(self.num_classes, pred_source2, labels_s, device))
            loss_seg.backward()
            self.losses['seg'].append(loss_seg.item())

            # Train with Target
            _, batch = next(target_iter)
            images_t, labels_t = batch
            images_t = images_t.to(device)
            pred_target1_, pred_target2_ = self.model(images_t)

            pred_target1 = interp_target(pred_target1_)
            pred_target2 = interp_target(pred_target2_)

            # Semi-supervised approach: use the target labels on a fraction
            # (args.target_frac) of the iterations.
            if args.use_target_labels and i_iter % int(
                    1 / args.target_frac) == 0:
                loss_seg_t = (loss_calc(args.num_classes, pred_target1,
                                        labels_t, device) +
                              loss_calc(args.num_classes, pred_target2,
                                        labels_t, device))
                loss_seg_t.backward()
                self.losses['seg_t'].append(loss_seg_t.item())

            # Adversarial Loss
            if args.method.adversarial:
                # BUG FIX: the target predictions were detached here, which
                # cut them out of the autograd graph, so loss_adv.backward()
                # propagated nothing into the segmentation network; the old
                # `loss_adv.requires_grad = True` line merely silenced the
                # resulting autograd error. Keep the graph alive for G and
                # detach only the weight map, which is a pure weighting.
                weight_map = weightmap(F.softmax(pred_target1, dim=1),
                                       F.softmax(pred_target2,
                                                 dim=1)).detach()

                D_out = interp_target(
                    self.model_D(F.softmax(pred_target1 + pred_target2,
                                           dim=1)))

                # Adaptive Adversarial Loss: after preheat, weight the BCE
                # locally by the discrepancy between the two classifiers.
                if i_iter > self.preheat:
                    loss_adv = self.weighted_bce_loss(
                        D_out,
                        torch.FloatTensor(D_out.data.size()).fill_(
                            self.source_label).to(device), weight_map,
                        args.Epsilon, args.Lambda_local)
                else:
                    loss_adv = self.bce_loss(
                        D_out,
                        torch.FloatTensor(D_out.data.size()).fill_(
                            self.source_label).to(device))

                loss_adv = loss_adv * self.args.Lambda_adv * damping
                loss_adv.backward()
                self.losses['adv'].append(loss_adv.item())

            # Weight Discrepancy Loss: cosine similarity between the
            # flattened weights of the two classifier heads (+1 keeps the
            # loss positive).
            if args.weight_loss:
                W5 = None
                W6 = None
                if args.model.name == 'DeepLab':  # TODO: ADD ERF-NET

                    for (w5, w6) in zip(self.model.layer5.parameters(),
                                        self.model.layer6.parameters()):
                        if W5 is None and W6 is None:
                            W5 = w5.view(-1)
                            W6 = w6.view(-1)
                        else:
                            W5 = torch.cat((W5, w5.view(-1)), 0)
                            W6 = torch.cat((W6, w6.view(-1)), 0)

                # BUG FIX: guard against unsupported models, for which
                # W5/W6 stay None and torch.matmul would raise a TypeError.
                if W5 is not None and W6 is not None:
                    loss_weight = (torch.matmul(W5, W6) /
                                   (torch.norm(W5) * torch.norm(W6)) + 1
                                   )  # +1 is for a positive loss
                    loss_weight = loss_weight * args.Lambda_weight * damping * 2
                    loss_weight.backward()
                    self.losses['weight'].append(loss_weight.item())

            # ======================================================================================
            # train D
            # ======================================================================================
            if args.method.adversarial:
                # Bring back Grads in D
                for param in self.model_D.parameters():
                    param.requires_grad = True

                # Train with Source; detach so D's loss does not reach G.
                pred_source1 = pred_source1.detach()
                pred_source2 = pred_source2.detach()

                D_out_s = interp_source(
                    self.model_D(F.softmax(pred_source1 + pred_source2,
                                           dim=1)))
                loss_D_s = self.bce_loss(
                    D_out_s,
                    torch.FloatTensor(D_out_s.data.size()).fill_(
                        self.source_label).to(device))
                loss_D_s.backward()
                self.losses['ds'].append(loss_D_s.item())

                # Train with Target
                pred_target1 = pred_target1.detach()
                pred_target2 = pred_target2.detach()
                weight_map = weight_map.detach()

                D_out_t = interp_target(
                    self.model_D(F.softmax(pred_target1 + pred_target2,
                                           dim=1)))

                # Adaptive Adversarial Loss
                if i_iter > self.preheat:
                    loss_D_t = self.weighted_bce_loss(
                        D_out_t,
                        torch.FloatTensor(D_out_t.data.size()).fill_(
                            self.target_label).to(device), weight_map,
                        args.Epsilon, args.Lambda_local)
                else:
                    loss_D_t = self.bce_loss(
                        D_out_t,
                        torch.FloatTensor(D_out_t.data.size()).fill_(
                            self.target_label).to(device))

                loss_D_t.backward()
                self.losses['dt'].append(loss_D_t.item())

            # ======================================================================================
            # Train SELF SUPERVISED TASK
            # ======================================================================================
            if args.method.self:
                ''' SELF-SUPERVISED (ROTATION) ALGORITHM 
                - Get squared prediction 
                - Rotate it randomly (0,90,180,270) -> assign self-label (0,1,2,3)  [*2 IF WANT TO CLASSIFY ALSO S/T]
                - Send rotated prediction to the classifier
                - Get loss 
                - Update weights of classifier and G (segmentation network) 
                '''
                # NOTE(review): the predictions are detached below, so only
                # the rotation classifier (model_A) actually receives
                # gradients here, not G — confirm this is intended.

                # Train with Source
                pred_source1 = pred_source1_.detach()
                pred_source2 = pred_source2_.detach()

                # Train with Target
                pred_target1 = pred_target1_.detach()
                pred_target2 = pred_target2_.detach()

                # BUG FIX: pred_source was never computed (the line below
                # was commented out), so the rotation step crashed with a
                # NameError at rotate_tensor(pred_source, ...).
                pred_source = interp_prediction(
                    F.softmax(pred_source1 + pred_source2, dim=1))
                pred_target = interp_prediction(
                    F.softmax(pred_target1 + pred_target2, dim=1))

                # ROTATE TENSORS
                # source. BUG FIX: Tensor.random_ needs an int bound; `/`
                # yields a float, so use floor division.
                label_source = torch.empty(1, dtype=torch.long).random_(
                    args.auxiliary.aux_classes // 2).to(device)
                rotated_pred_source = rotate_tensor(
                    pred_source, self.rotations[label_source.item()])
                pred_source_label = self.model_A(rotated_pred_source)
                loss_rot_source = self.aux_loss(pred_source_label,
                                                label_source)

                # target
                label_target = torch.empty(1, dtype=torch.long).random_(
                    args.auxiliary.aux_classes).to(device)
                rotated_pred_target = rotate_tensor(
                    pred_target, self.rotations[label_target.item()])
                pred_target_label = self.model_A(rotated_pred_target)
                loss_rot_target = self.aux_loss(pred_target_label,
                                                label_target)

                loss_rot = (loss_rot_source +
                            loss_rot_target) * args.Lambda_aux
                loss_rot.backward()
                self.losses['aux'].append(loss_rot.item())

            # Optimizers steps
            self.optimizer.step()
            if args.method.adversarial:
                self.optimizer_D.step()
            if args.method.self:
                self.optimizer_A.step()

            if i_iter % 10 == 0:
                log.info(
                    'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f} loss_rot = {3:.4f}, loss_adv = {4:.4f}, loss_weight = {5:.4f}, loss_D_s = {6:.4f} loss_D_t = {7:.4f}'
                    .format(i_iter, self.num_steps, loss_seg, loss_rot,
                            loss_adv, loss_weight, loss_D_s, loss_D_t))

            if (i_iter % args.save_pred_every == 0
                    and i_iter != 0) or i_iter == self.num_steps - 1:
                log.info('saving weights...')
                i_iter = i_iter if i_iter != self.num_steps - 1 else i_iter + 1  # for last iter
                torch.save(
                    self.model.state_dict(),
                    join(args.snapshot_dir, 'GTA5_' + str(i_iter) + '.pth'))
                if args.method.adversarial:
                    torch.save(
                        self.model_D.state_dict(),
                        join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_D.pth'))
                if args.method.self:
                    torch.save(
                        self.model_A.state_dict(),
                        join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '_Aux.pth'))

                self.validate(i_iter, val_loader)
                compute_mIoU(i_iter, args.datasets.target.val.label_dir,
                             self.save_path, args.datasets.target.json_file,
                             args.datasets.target.base_list, args.results_dir)
                save_losses_plot(args.results_dir, self.losses)

                # Save also images of source and target.
                save_segmentations(args.images_dir, images_s, labels_s,
                                   pred_source1, images_t)

            # Free per-iteration tensors before the next big allocation.
            del images_s, labels_s, pred_source1, pred_source2, pred_source1_, pred_source2_
            del images_t, labels_t, pred_target1, pred_target2, pred_target1_, pred_target2_

        end = time.time()
        days = int((end - start) / 86400)
        log.info(
            'Total training time: {} days, {} hours, {} min, {} sec '.format(
                days,
                int((end - start) / 3600) - (days * 24),
                int((end - start) / 60 % 60), int((end - start) % 60)))
        print('### Experiment: ' + args.experiment + ' finished ###')
def train_semi(
        model_G,
        model_D,
        trainloader_remain,
        optimizer_G,
        train_dataset_size,
        device
):
    """
    Train the generator on unlabelled data (GT is NOT available).

    The (frozen) discriminator output is used twice: as an adversarial
    target (loss_semi_adv) and, thresholded with MASK_T, as a confidence
    mask to build pseudo ground truth for a masked CE loss (loss_semi).

    :param model_G: segmentation network; grads are enabled.
    :param model_D: discriminator; grads are frozen during this phase.
    :param trainloader_remain: dataloader over the unlabelled split,
        yielding (points, cls, _).
    :param optimizer_G: optimizer stepping model_G's parameters.
    :param train_dataset_size: number of training samples (tqdm total).
    :param device: torch device to run on.
    :return: (pred_remain, ignore_mask_remain from the last batch,
        mean loss_semi_adv, mean loss_semi).
    """
    loss_semi_adv_value = []
    loss_semi_value = []
    # BUG FIX: tqdm's `total` expects an int; np.floor returns a float.
    # Integer floor division preserves the original count.
    NUM_BATCHES = train_dataset_size // BATCH_SIZE
    for i, mini_batch in tqdm.tqdm(enumerate(trainloader_remain), total=NUM_BATCHES):
        # don't accumulate grads in D
        for param in model_G.parameters():
            param.requires_grad = True
        for param in model_D.parameters():
            param.requires_grad = False

        optimizer_G.zero_grad()

        # only access to points
        points, cls, _ = mini_batch
        points, cls = Variable(points).float(), Variable(cls).float()
        points, cls = points.to(device), cls.to(device)

        pred = model_G(points, cls)
        pred_remain = pred.detach()

        D_out = model_D(F.softmax(pred, dim=2))
        D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy()  # BxN
        # BUG FIX: np.bool was removed in NumPy 1.24; use the builtin bool.
        ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(bool)  # Bx2048

        ### semi_adv: push D to see the prediction as GT ###
        loss_semi_adv = LAMBDA_SEMI_ADV * loss_bce(D_out, make_D_label(
            GT_LABEL,
            ignore_mask_remain,
            device),
            device
        )

        ### semi: masked CE against D-confident pseudo labels ###
        # Points where D is below the confidence threshold get the 999
        # ignore label and are excluded by the masked loss.
        semi_ignore_mask = (D_out_sigmoid < MASK_T)
        semi_gt = pred.data.cpu().numpy().argmax(axis=2)
        semi_gt[semi_ignore_mask] = 999

        semi_ratio = 1.0 - float(semi_ignore_mask.sum()) / semi_ignore_mask.size
        print('semi ratio: {:.4f}'.format(semi_ratio))

        if semi_ratio == 0.0:
            raise ValueError("Semi ratio == 0!")
        else:
            # NOTE(review): pseudo labels are passed as a FloatTensor —
            # presumably loss_calc(mask=True) converts them; confirm.
            semi_gt = torch.FloatTensor(semi_gt)
            loss_semi = LAMBDA_SEMI * loss_calc(pred, semi_gt, device, mask=True)
            loss_semi_value.append(loss_semi.item())
            loss_semi += loss_semi_adv
            loss_semi.backward()
            optimizer_G.step()
            loss_semi_adv_value.append(loss_semi_adv.item())

    return pred_remain, ignore_mask_remain, np.average(loss_semi_adv_value), np.average(loss_semi_value)
# --- Example #4 ---
    def train(self, tar_loader, val_loader):
        """
        Supervised fine-tuning loop on the target domain only.

        On a fraction of the iterations (args.target_frac) trains the
        segmentation network with the target labels, plus an optional
        weight-discrepancy loss between the two classifier heads.
        Periodically saves checkpoints, validates and computes mIoU.

        :param tar_loader: dataloader over the labelled target domain.
        :param val_loader: dataloader used by self.validate.
        """
        loss_weight = 0
        args = self.args
        log = self.logger
        device = self.device

        # Upsample network outputs back to the full target image size.
        interp_target = nn.Upsample(size=(args.datasets.target.images_size[1],
                                          args.datasets.target.images_size[0]),
                                    mode='bilinear',
                                    align_corners=True)

        target_iter = enumerate(tar_loader)

        self.model.train()
        self.model = self.model.to(device)

        log.info('###########   TRAINING STARTED  ############')
        start = time.time()

        for i_iter in range(self.start_iter, self.num_steps):

            if i_iter % int(1 / args.target_frac) == 0:
                self.model.train()
                self.optimizer.zero_grad()
                adjust_learning_rate(self.optimizer, self.preheat,
                                     args.num_steps, args.power, i_iter,
                                     args.model.optimizer)

                # Linearly decaying weight for the auxiliary loss.
                damping = (1 - i_iter / self.num_steps
                           )  # similar to early stopping

                # Train with Target
                _, batch = next(target_iter)
                images_t, labels_t = batch
                images_t = images_t.to(device)
                pred_target1, pred_target2 = self.model(images_t)

                pred_target1 = interp_target(pred_target1)
                pred_target2 = interp_target(pred_target2)

                # Segmentation loss from both classifier heads.
                loss_seg_t = (loss_calc(args.num_classes, pred_target1,
                                        labels_t, device) +
                              loss_calc(args.num_classes, pred_target2,
                                        labels_t, device))
                loss_seg_t.backward()
                self.losses['seg_t'].append(loss_seg_t.item())

                # Weight Discrepancy Loss
                if args.weight_loss:

                    W5 = None
                    W6 = None
                    # TODO: ADD ERF-NET
                    if args.model.name == 'DeepLab':

                        for (w5, w6) in zip(self.model.layer5.parameters(),
                                            self.model.layer6.parameters()):
                            if W5 is None and W6 is None:
                                W5 = w5.view(-1)
                                W6 = w6.view(-1)
                            else:
                                W5 = torch.cat((W5, w5.view(-1)), 0)
                                W6 = torch.cat((W6, w6.view(-1)), 0)

                    # BUG FIX: guard against unsupported models, for which
                    # W5/W6 stay None and torch.matmul would raise.
                    if W5 is not None and W6 is not None:
                        # Cosine distance between W5 and W6 vectors
                        loss_weight = (torch.matmul(W5, W6) /
                                       (torch.norm(W5) * torch.norm(W6)) + 1
                                       )  # +1 is for a positive loss
                        loss_weight = loss_weight * args.Lambda_weight * damping * 2
                        loss_weight.backward()
                        self.losses['weight'].append(loss_weight.item())

                # Optimizers steps
                self.optimizer.step()

                if i_iter % 10 == 0:
                    log.info(
                        'Iter = {0:6d}/{1:6d}, loss_seg = {2:.4f}, loss_weight = {3:.4f}'
                        .format(i_iter, self.num_steps, loss_seg_t,
                                loss_weight))

                if (i_iter % args.save_pred_every == 0
                        and i_iter != 0) or i_iter == self.num_steps - 1:
                    log.info('saving weights...')
                    i_iter = i_iter if i_iter != self.num_steps - 1 else i_iter + 1  # for last iter
                    torch.save(
                        self.model.state_dict(),
                        join(args.snapshot_dir,
                             'GTA5_' + str(i_iter) + '.pth'))

                    self.validate(i_iter, val_loader)
                    compute_mIoU(i_iter, args.datasets.target.val.label_dir,
                                 self.save_path,
                                 args.datasets.target.json_file,
                                 args.datasets.target.base_list,
                                 args.results_dir)
                    #save_losses_plot(args.results_dir, self.losses)

                    # SAVE ALSO IMAGES OF SOURCE AND TARGET
                    #save_segmentations(args.images_dir, images_s, labels_s, pred_source1, images_t)

                # Free per-iteration tensors before the next allocation.
                del images_t, labels_t, pred_target1, pred_target2

        end = time.time()
        days = int((end - start) / 86400)
        log.info(
            'Total training time: {} days, {} hours, {} min, {} sec '.format(
                days,
                int((end - start) / 3600) - (days * 24),
                int((end - start) / 60 % 60), int((end - start) % 60)))
        print('### Experiment: ' + args.experiment + ' finished ###')
# --- Example #5 ---
def train(opts):
    if not os.path.exists(opts.snapshot_dir):
        os.makedirs(opts.snapshot_dir)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_G = create_model(type="generator",
                           num_seg_classes=opts.num_seg_classes)
    model_D = create_model(type="discriminator",
                           num_seg_classes=opts.num_seg_classes)

    model_G.to(device)
    model_G.train()

    model_D.to(device)
    model_D.train()

    train_dataset = create_dataset(
        num_inst_classes=NUM_INST_CLASSES,
        num_pts=NUM_PTS,
        mode="train",
        is_noise=IS_NOISE,
        is_rotate=IS_ROTATE,
    )
    train_dataset_size = len(train_dataset)
    print("#Total train: {:6d}".format(train_dataset_size))

    train_gt_dataset = create_GT_dataset(
        num_inst_classes=NUM_INST_CLASSES,
        num_pts=NUM_PTS,
    )

    if opts.partial_data is None:
        trainloader = create_dataloader(
            dataset=train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
        )
        trainloader_gt = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
        )
        trainloader_iter = iter(trainloader)
        trainloader_gt_iter = iter(trainloader_gt)
    else:
        partial_size = int(opts.partial_data * train_dataset_size)

        if opts.partial_id is not None:
            train_ids = pickle.load(open(opts.partial_id))
            print('loading train ids from {}'.format(opts.partial_id))
        else:
            train_ids = list(range(train_dataset_size))
            np.random.shuffle(train_ids)

        pickle.dump(
            train_ids,
            open(os.path.join(opts.snapshot_dir, 'train_id.pkl'), 'wb'))

        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])
        train_remain_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[partial_size:])
        train_gt_sampler = torch.utils.data.sampler.SubsetRandomSampler(
            train_ids[:partial_size])

        trainloader = create_dataloader(
            dataset=train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
            sampler=train_sampler,
        )
        trainloader_gt = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
            sampler=train_gt_sampler,
        )
        trainloader_remain = create_dataloader(
            dataset=train_gt_dataset,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            shuffle=IS_SHUFFLE,
            pin_memory=True,
            sampler=train_remain_sampler,
        )
        trainloader_remain_iter = iter(trainloader_remain)
        trainloader_iter = iter(trainloader)
        trainloader_gt_iter = iter(trainloader_gt)

    # optimizer for segmentation network
    optimizer = optim.Adam(model_G.parameters(), lr=opts.lr_G)
    optimizer.zero_grad()

    # optimizer for discriminator network
    optimizer_D = optim.Adam(model_D.parameters(),
                             lr=opts.lr_G,
                             betas=(0.9, 0.999))
    optimizer_D.zero_grad()

    # labels for adversarial training
    pred_label = 0
    gt_label = 1

    i_iter = 0
    for epoch in np.arange(NUM_EPOCHS):
        loss_ce_value = 0
        loss_adv_value = 0
        loss_D_value = 0
        loss_semi_value = 0
        loss_semi_adv_value = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, i_iter, LR_G,
                             NUM_EPOCHS * train_dataset_size / (BATCH_SIZE),
                             POWER)
        optimizer_D.zero_grad()
        adjust_learning_rate_D(optimizer_D, i_iter, LR_D,
                               NUM_EPOCHS * train_dataset_size / (BATCH_SIZE),
                               POWER)

        if epoch >= 0 and epoch <= 9:  # only train generator
            for i, mini_batch in enumerate(trainloader):
                # don't accumulate grads in D
                for param in model_D.parameters():
                    param.requires_grad = False

                points, cls_gt, seg_gt = mini_batch
                points, cls_gt, seg_gt = Variable(points).float(), \
                                         Variable(cls_gt).float(), \
                                         Variable(seg_gt).type(torch.LongTensor)
                points, cls_gt, seg_gt = points.to(device), \
                                         cls_gt.to(device), \
                                         seg_gt.to(device)
                pred = model_G(points, cls_gt)
                # loss_ce
                loss_ce = loss_calc(pred, seg_gt, device, mask=False)
                # loss_adv
                D_out = model_D(F.softmax(pred, dim=2))
                ignore_mask = np.zeros(seg_gt.shape).astype(np.bool)
                loss_adv = loss_bce(
                    D_out, make_D_label(gt_label, ignore_mask, device), device)
                loss_seg = loss_ce + LAMBDA_ADV * loss_adv
                loss_seg.backward()
                loss_ce_value += loss_ce.item()
                loss_adv_value += loss_adv.item()

            fastprint('[%d/%d] CE loss: %.3f, ADV loss: %.3f' %
                      (epoch, NUM_EPOCHS, loss_ce_value, loss_adv_value))
        elif epoch >= 10 and epoch <= 19:  # only train discriminator
            for i, mini_batch in enumerate(trainloader):
                # don't accumulate grads in G
                for param in model_G.parameters():
                    param.requires_grad = False
                for param in model_D.parameters():
                    param.requires_grad = True

                points, cls_gt, seg_gt = mini_batch
                points, cls_gt, seg_gt = Variable(points).float(), \
                                         Variable(cls_gt).float(), \
                                         Variable(seg_gt).type(torch.LongTensor)
                points, cls_gt, seg_gt = points.to(device), \
                                         cls_gt.to(device), \
                                         seg_gt.to(device)

                ignore_mask_gt = np.zeros(seg_gt.shape).astype(np.bool)
                D_gt_v = Variable(one_hot(seg_gt,
                                          NUM_SEG_CLASSES)).float().to(device)
                D_out = model_D(D_gt_v)
                loss_D_gt = loss_bce(
                    D_out, make_D_label(gt_label, ignore_mask_gt, device),
                    device)

                ignore_mask = np.zeros(seg_gt.shape).astype(np.bool)
                pred = model_G(points, cls_gt)
                pred = pred.detach()
                D_out = model_D(F.softmax(pred, dim=2))
                loss_D_pred = loss_bce(
                    D_out, make_D_label(pred_label, ignore_mask, device),
                    device)
                loss_D = loss_D_gt + loss_D_pred
                loss_D.backward()
                loss_D_value += loss_D.item()
        else:  # start unlabeled data
            for i, mini_batch in enumerate(trainloader_remain):
                # don't accumulate grads in D
                for param in model_D.parameters():
                    param.requires_grad = False

                # only access to img
                points, cls, _ = mini_batch
                points, cls = Variable(points).float(), Variable(cls).float()
                points, cls = points.to(device), cls.to(device)

                pred = model_G(points, cls)  # BxNxC
                pred_remain = pred.detach()

                D_out = model_D(F.softmax(pred, dim=2))
                D_out_sigmoid = torch.sigmoid(D_out).data.cpu().numpy()  # BxN
                ignore_mask_remain = np.zeros(D_out_sigmoid.shape).astype(
                    np.bool)  # Bx2048

                ### semi_adv ###
                loss_semi_adv = LAMBDA_SEMI_ADV * loss_bce(
                    D_out, make_D_label(gt_label, ignore_mask_remain, device),
                    device)

                ### semi ###
                semi_ignore_mask = (D_out_sigmoid < MASK_T)
                semi_gt = pred.data.cpu().numpy().argmax(axis=2)
                semi_gt[semi_ignore_mask] = 999

                semi_ratio = 1.0 - float(
                    semi_ignore_mask.sum()) / semi_ignore_mask.size
                print('semi ratio: {:.4f}'.format(semi_ratio))

                if semi_ratio == 0.0:
                    loss_semi_value += 0
                    raise ValueError("Semi ratio == 0!")
                else:
                    semi_gt = torch.FloatTensor(semi_gt)
                    loss_semi = LAMBDA_SEMI * loss_calc(
                        pred, semi_gt, device, mask=True)
                    loss_semi += loss_semi_adv
                    loss_semi.backward()
                    loss_semi_adv_value += loss_semi_adv.item()
                    loss_semi_value += loss_semi.item()