Exemple #1
0
    def __init__(self, model, cfg):
        '''model(nn.Layer): A sementic segmentation model.'''
        self.cfg = cfg
        self.ema_decay = cfg['ema_decay']
        self.edgeconstrain = cfg['edgeconstrain']
        self.edgepullin = cfg['edgepullin']
        self.src_only = cfg['src_only']
        self.featurepullin = cfg['featurepullin']

        self.model = model
        self.ema = EMA(self.model, self.ema_decay)
        self.celoss = losses.CrossEntropyLoss()
        self.klloss = losses.KLLoss()
        self.mseloss = losses.MSELoss()
        self.bceloss_src = losses.BCELoss(weight='dynamic')
        self.bceloss_tgt = losses.BCELoss(weight='dynamic')
        self.src_centers = [paddle.zeros((1, 19)) for _ in range(19)]
        self.tgt_centers = [paddle.zeros((1, 19)) for _ in range(19)]

        if 'None' in cfg['resume_ema']:
            self.resume_ema = None
        else:
            self.resume_ema = cfg['resume_ema']
        gc.collect()
        #for i in range(6):
        #    plt.imshow(data[-1]['images'][i])
        #    plt.show()

        ##### If fix wavelet #######
        #if wavelet_opt <= 0:
        #    pht.freeze_persistence = True
        #    eval_model.freeze_persistence = True
        #    for name, param in pht.named_parameters():
        #        if 'rbfweights' in name:
        #            param.requires_grad = False
        #        else:
        #            param.requires_grad = True

        ema = EMA(ema_decay)
        ema.register(pht.state_dict(keep_vars=False))

        if class_weighting:
            criterion = nn.BCEWithLogitsLoss(
                pos_weight=pos_weight)  #loss function
        else:
            criterion = nn.BCEWithLogitsLoss()  #loss function

        optimizer_classifier = optim.Adam([{
            'params': [
                param for name, param in pht.named_parameters()
                if 'rbfweights' not in name
            ]
        }],
                                          lr=lr_cl,
Exemple #3
0
class Trainer():
    """Train a segmentation model on a labelled source and a target dataset.

    Each iteration of ``train`` sums a supervised source loss, a pseudo-label
    loss on augmented target images and up to three optional terms switched
    by ``cfg`` ('edgeconstrain', 'edgepullin', 'featurepullin'). An EMA copy
    of the weights is updated after every optimizer step and is the one that
    gets evaluated.
    """

    # Momentum of the running per-class feature centers.
    _CENTER_MOMENTUM = 0.99
    # The target edge loss is only added after this many iterations.
    _TGT_EDGE_START_ITER = 200000
    # The center-alignment losses are only added from this iteration on.
    _FEAT_ALIGN_START_ITER = 3000

    def __init__(self, model, cfg):
        """Store the model, read the training switches and build the losses.

        Args:
            model (nn.Layer): A semantic segmentation model.
            cfg: Config indexed as ``cfg[key]``. Keys read here: 'ema_decay',
                'edgeconstrain', 'edgepullin', 'src_only', 'featurepullin'
                and 'resume_ema'. ``train`` additionally reads 'save_edge'
                and 'eval_src'.
        """
        self.cfg = cfg
        self.ema_decay = cfg['ema_decay']
        self.edgeconstrain = cfg['edgeconstrain']
        self.edgepullin = cfg['edgepullin']
        self.src_only = cfg['src_only']
        self.featurepullin = cfg['featurepullin']

        self.model = model
        self.ema = EMA(self.model, self.ema_decay)
        self.celoss = losses.CrossEntropyLoss()
        self.klloss = losses.KLLoss()
        self.mseloss = losses.MSELoss()
        self.bceloss_src = losses.BCELoss(weight='dynamic')
        self.bceloss_tgt = losses.BCELoss(weight='dynamic')
        # One running center per class, each of shape (1, 19).
        # NOTE(review): both 19s are hard-coded; presumably they must match
        # the dataset's NUM_CLASSES and the channel count of the main logits.
        self.src_centers = [paddle.zeros((1, 19)) for _ in range(19)]
        self.tgt_centers = [paddle.zeros((1, 19)) for _ in range(19)]

        # Accept a real None as well as any string containing 'None'
        # (the original test); a real None used to raise TypeError here.
        resume_ema = cfg['resume_ema']
        if resume_ema is None or 'None' in resume_ema:
            self.resume_ema = None
        else:
            self.resume_ema = resume_ema

    @staticmethod
    def _build_loader(dataset, batch_size, num_workers):
        """Return ``(batch_sampler, loader)`` yielding shuffled, full batches."""
        batch_sampler = paddle.io.DistributedBatchSampler(
            dataset, batch_size=batch_size, shuffle=True, drop_last=True)
        loader = paddle.io.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            num_workers=num_workers,
            return_list=True,
            worker_init_fn=worker_init_fn,
        )
        return batch_sampler, loader

    @staticmethod
    def _percent_equal(pred, target):
        """Return the percentage of elements of ``pred`` that equal ``target``."""
        num_elements = functools.reduce(lambda a, b: a * b, pred.shape)
        matches = (pred == target).numpy().sum().astype('float32')
        return (matches / num_elements) * 100

    @staticmethod
    def _masked_center(feat, mask):
        """Return the mean of ``feat`` over the pixels selected by ``mask``.

        Args:
            feat: Feature map indexed as (N, C, H, W); the mean is taken over
                axes 2 and 3.
            mask: Boolean tensor that can be expanded to ``feat``'s shape.

        Returns:
            The per-channel mean over the selected pixels, or None when the
            mask selects nothing.
        """
        sel_num = paddle.sum(mask.astype('float32'))
        if not sel_num > 0:
            return None
        selected = paddle.where(
            mask.expand_as(feat), feat, paddle.zeros(feat.shape))
        # mean() divides by every pixel of *this* map, so rescale by the
        # selected fraction of the same map to get the mean over the mask.
        total_pixs = feat.shape[2] * feat.shape[3]
        return paddle.mean(selected, axis=[2, 3]) / (sel_num / total_pixs)

    def train(self,
              train_dataset_src,
              train_dataset_tgt,
              val_dataset_tgt=None,
              val_dataset_src=None,
              optimizer=None,
              save_dir='output',
              iters=10000,
              batch_size=2,
              resume_model=None,
              save_interval=1000,
              log_iters=10,
              num_workers=0,
              use_vdl=False,
              keep_checkpoint_max=5,
              test_config=None):
        """
        Launch training.

        Every batch is indexed as ``data[0]`` (images), ``data[1]`` (labels)
        and ``data[2]`` (edge maps) for both datasets. Training stops as soon
        as ``iters`` iterations have been run.

        Args:
            train_dataset_src (paddle.io.Dataset): Source-domain training dataset.
            train_dataset_tgt (paddle.io.Dataset): Target-domain training dataset.
                Its ``NUM_CLASSES`` attribute is read.
            val_dataset_tgt (paddle.io.Dataset, optional): Target validation dataset.
                When None, no evaluation is run and no best model is kept.
            val_dataset_src (paddle.io.Dataset, optional): Source validation dataset,
                evaluated every ``save_interval * 30`` iters when cfg['eval_src'] is set.
            optimizer (paddle.optimizer.Optimizer): The optimizer.
            save_dir (str, optional): The directory for saving the model snapshot. Default: 'output'.
            iters (int, optional): How many iters to train the model. Default: 10000.
            batch_size (int, optional): Mini batch size of one gpu or cpu. Default: 2.
            resume_model (str, optional): The path of resume model.
            save_interval (int, optional): How many iters to save a model snapshot once during training. Default: 1000.
            log_iters (int, optional): Display logging information at every log_iters. Default: 10.
            num_workers (int, optional): Num workers for data loader. Default: 0.
            use_vdl (bool, optional): Whether to record the data to VisualDL during training. Default: False.
            keep_checkpoint_max (int, optional): Maximum number of checkpoints to save. Default: 5.
            test_config(dict, optional): Evaluation config.

        Raises:
            ValueError: If training is requested but one of the training
                loaders yields no batch (the loop could never advance).
        """
        start_iter = 0
        self.model.train()
        nranks = paddle.distributed.ParallelEnv().nranks
        local_rank = paddle.distributed.ParallelEnv().local_rank

        if resume_model is not None:
            logger.info(resume_model)
            start_iter = resume(self.model, optimizer, resume_model)
        load_ema_model(self.model, self.resume_ema)

        # Replace a plain file that squats on the path, then create the
        # directory; exist_ok avoids a race between ranks.
        if os.path.exists(save_dir) and not os.path.isdir(save_dir):
            os.remove(save_dir)
        os.makedirs(save_dir, exist_ok=True)

        # `net` is what runs the forward passes: the fleet wrapper when
        # several ranks are used, the bare model otherwise.
        net = self.model
        if nranks > 1:
            paddle.distributed.fleet.init(is_collective=True)
            optimizer = paddle.distributed.fleet.distributed_optimizer(
                optimizer)  # The return is Fleet object
            net = paddle.distributed.fleet.distributed_model(self.model)

        batch_sampler_src, loader_src = self._build_loader(
            train_dataset_src, batch_size, num_workers)
        batch_sampler_tgt, loader_tgt = self._build_loader(
            train_dataset_tgt, batch_size, num_workers)

        if start_iter < iters and min(len(batch_sampler_src),
                                      len(batch_sampler_tgt)) == 0:
            raise ValueError(
                'A training loader yields no batch; check the dataset sizes '
                'against batch_size (incomplete batches are dropped).')

        if use_vdl:
            from visualdl import LogWriter
            log_writer = LogWriter(save_dir)

        iters_per_epoch = len(batch_sampler_tgt)
        best_mean_iou = -1.0
        best_model_iter = -1
        reader_cost_averager = TimeAverager()
        batch_cost_averager = TimeAverager()
        save_models = deque()
        batch_start = time.time()

        images_src = None  # stays None when no batch is processed
        cur_iter = start_iter
        while cur_iter < iters:
            # zip() ends the epoch with the shorter of the two loaders.
            for data_src, data_tgt in zip(loader_src, loader_tgt):

                reader_cost_averager.record(time.time() - batch_start)
                loss_dict = {}
                # Only filled by the edge-constraint branch; the save_edge
                # branch computes them on demand when that branch is off.
                src_edge = None
                src_edge_acc = None

                #### training #####
                images_tgt = data_tgt[0]
                labels_tgt = data_tgt[1].astype('int64')
                images_src = data_src[0]
                labels_src = data_src[1].astype('int64')

                edges_src = data_src[2].astype('int64')
                edges_tgt = data_tgt[2].astype('int64')

                logits_list_src = net(images_src)

                ##### source seg & edge loss ####
                loss_src_seg_main = self.celoss(logits_list_src[0], labels_src)
                loss_src_seg_aux = 0.1 * self.celoss(logits_list_src[1],
                                                     labels_src)

                loss_src_seg = loss_src_seg_main + loss_src_seg_aux
                loss_dict["source_main"] = loss_src_seg_main.numpy()[0]
                loss_dict["source_aux"] = loss_src_seg_aux.numpy()[0]
                loss = loss_src_seg
                del loss_src_seg, loss_src_seg_aux, loss_src_seg_main

                #### generate target pseudo label  ####
                with paddle.no_grad():
                    logits_list_tgt = net(images_tgt)

                    pred_P_1 = F.softmax(logits_list_tgt[0], axis=1)
                    labels_tgt_psu = paddle.argmax(pred_P_1.detach(), axis=1)

                    # aux label: argmax of the averaged main/aux probabilities
                    pred_P_2 = F.softmax(logits_list_tgt[1], axis=1)
                    pred_c = (pred_P_1 + pred_P_2) / 2
                    labels_tgt_psu_aux = paddle.argmax(pred_c.detach(), axis=1)

                if self.edgeconstrain:
                    loss_src_edge = self.bceloss_src(logits_list_src[2],
                                                     edges_src)
                    src_edge = paddle.argmax(
                        logits_list_src[2].detach().clone(), axis=1)
                    src_edge_acc = self._percent_equal(src_edge, edges_src)

                    if (not self.src_only) and (
                            cur_iter > self._TGT_EDGE_START_ITER):
                        ####  target seg & edge loss ####
                        logger.info("Add target edege loss")
                        # Target edges are derived from the pseudo labels and
                        # replace the edge maps read from the target batch.
                        edges_tgt = Func.mask_to_binary_edge(
                            labels_tgt_psu.detach().clone().numpy(),
                            radius=2,
                            num_classes=train_dataset_tgt.NUM_CLASSES)
                        edges_tgt = paddle.to_tensor(edges_tgt, dtype='int64')

                        loss_tgt_edge = self.bceloss_tgt(
                            logits_list_tgt[2], edges_tgt)
                        loss_edge = loss_tgt_edge + loss_src_edge
                    else:
                        loss_tgt_edge = paddle.zeros([1])
                        loss_edge = loss_src_edge

                    loss += loss_edge

                    loss_dict['target_edge'] = loss_tgt_edge.numpy()[0]
                    loss_dict['source_edge'] = loss_src_edge.numpy()[0]

                    del loss_edge, loss_tgt_edge, loss_src_edge

                #### target aug loss #######
                # The same augmentation object is applied to both pseudo-label
                # variants.
                augs = augmentation.get_augmentation()
                images_tgt_aug, labels_tgt_aug = augmentation.augment(
                    images=images_tgt.cpu(),
                    labels=labels_tgt_psu.detach().cpu(),
                    aug=augs,
                    iters="{}_1".format(cur_iter))
                images_tgt_aug = images_tgt_aug.cuda()
                labels_tgt_aug = labels_tgt_aug.cuda()

                _, labels_tgt_aug_aux = augmentation.augment(
                    images=images_tgt.cpu(),
                    labels=labels_tgt_psu_aux.detach().cpu(),
                    aug=augs,
                    iters="{}_2".format(cur_iter))
                labels_tgt_aug_aux = labels_tgt_aug_aux.cuda()

                logits_list_tgt_aug = net(images_tgt_aug)

                loss_tgt_aug_main = 0.1 * (self.celoss(logits_list_tgt_aug[0],
                                                       labels_tgt_aug))
                loss_tgt_aug_aux = 0.1 * (0.1 * self.celoss(
                    logits_list_tgt_aug[1], labels_tgt_aug_aux))

                loss_tgt_aug = loss_tgt_aug_aux + loss_tgt_aug_main

                loss += loss_tgt_aug

                loss_dict['target_aug_main'] = loss_tgt_aug_main.numpy()[0]
                loss_dict['target_aug_aux'] = loss_tgt_aug_aux.numpy()[0]
                del images_tgt_aug, labels_tgt_aug_aux, images_tgt, \
                    loss_tgt_aug, loss_tgt_aug_aux, loss_tgt_aug_main

                # Bound unconditionally: the save_edge branch reads both even
                # when edge pull-in is disabled.
                src_edge_logit = logits_list_src[2]
                tgt_edge_logit = logits_list_tgt_aug[2]

                #### edge input seg; src & tgt edge pull in ######
                if self.edgepullin:
                    # detach(): only the fusion head receives gradients here.
                    feat_src = paddle.concat(
                        [logits_list_src[0], src_edge_logit], axis=1).detach()

                    out_src = self.model.fusion(feat_src)
                    loss_src_edge_rec = self.celoss(out_src, labels_src)

                    feat_tgt = paddle.concat(
                        [logits_list_tgt[0], tgt_edge_logit], axis=1).detach()

                    out_tgt = self.model.fusion(feat_tgt)
                    loss_tgt_edge_rec = self.celoss(out_tgt, labels_tgt)

                    loss_edge_rec = loss_tgt_edge_rec + loss_src_edge_rec
                    loss += loss_edge_rec

                    loss_dict['src_edge_rec'] = loss_src_edge_rec.numpy()[0]
                    loss_dict['tgt_edge_rec'] = loss_tgt_edge_rec.numpy()[0]

                    del loss_tgt_edge_rec, loss_src_edge_rec

                #### mask input feature & pullin  ######
                if self.featurepullin:
                    # inner-class loss
                    feat_src = logits_list_src[0]
                    feat_tgt = logits_list_tgt_aug[0]
                    center_src_s, center_tgt_s = [], []
                    momentum = self._CENTER_MOMENTUM

                    # Class maps do not depend on the class index, so they
                    # are computed once instead of once per class.
                    pred_src = paddle.argmax(
                        logits_list_src[0].detach().clone(),
                        axis=1).unsqueeze(0)
                    pred_tgt = labels_tgt_aug.unsqueeze(0)

                    for i in range(train_dataset_tgt.NUM_CLASSES):
                        center_src = self._masked_center(
                            feat_src, pred_src == i)
                        if center_src is None:
                            # Class absent from this batch: keep the running
                            # center untouched and use it as this batch's
                            # center, so the class adds nothing to the
                            # alignment loss (previously a stale center from
                            # another class, or a NameError).
                            center_src = self.src_centers[i]
                        else:
                            self.src_centers[i] = momentum * self.src_centers[
                                i] + (1 - momentum) * center_src

                        center_tgt = self._masked_center(
                            feat_tgt, pred_tgt == i)
                        if center_tgt is None:
                            center_tgt = self.tgt_centers[i]
                        else:
                            self.tgt_centers[i] = momentum * self.tgt_centers[
                                i] + (1 - momentum) * center_tgt

                        center_src_s.append(center_src)
                        center_tgt_s.append(center_tgt)

                    # average center structure alignment
                    if cur_iter >= self._FEAT_ALIGN_START_ITER:
                        src_centers = paddle.concat(self.src_centers, axis=0)
                        tgt_centers = paddle.concat(self.tgt_centers, axis=0)

                        # Class-by-class similarity matrices of the centers.
                        relatmat_src = paddle.matmul(src_centers,
                                                     src_centers,
                                                     transpose_y=True)
                        relatmat_tgt = paddle.matmul(tgt_centers,
                                                     tgt_centers,
                                                     transpose_y=True)

                        # Both matrices are pulled towards their average.
                        relatmat_mean = (relatmat_tgt + relatmat_src) / 2
                        loss_intra_relate = self.klloss(relatmat_src, relatmat_mean) \
                                            + self.klloss(relatmat_tgt, relatmat_mean)

                        # NOTE(review): paddle.to_tensor on a list of tensors
                        # may copy the values and drop their gradient history
                        # in some Paddle versions - confirm these two losses
                        # actually back-propagate.
                        loss_pix_align_src = self.mseloss(
                            paddle.to_tensor(center_src_s),
                            paddle.to_tensor(
                                self.src_centers).detach().clone())
                        loss_pix_align_tgt = self.mseloss(
                            paddle.to_tensor(center_tgt_s),
                            paddle.to_tensor(
                                self.tgt_centers).detach().clone())

                        loss_feat_align = loss_pix_align_src + loss_pix_align_tgt + loss_intra_relate

                        loss += loss_feat_align

                        loss_dict['loss_pix_align_src'] = \
                                loss_pix_align_src.numpy()[0]
                        loss_dict['loss_pix_align_tgt'] = \
                                loss_pix_align_tgt.numpy()[0]
                        loss_dict['loss_intra_relate'] = \
                                loss_intra_relate.numpy()[0]

                        del loss_pix_align_tgt, loss_pix_align_src, loss_intra_relate

                    # Cut the running centers off this iteration's graph
                    # before they are reused in the next one.
                    self.tgt_centers = [
                        item.detach().clone() for item in self.tgt_centers
                    ]
                    self.src_centers = [
                        item.detach().clone() for item in self.src_centers
                    ]

                loss.backward()
                del loss
                # From here on only the logged scalar total is needed.
                loss_value = sum(loss_dict.values())

                optimizer.step()
                self.ema.update_params()

                with paddle.no_grad():
                    ##### log & save #####
                    lr = optimizer.get_lr()
                    # update lr
                    if isinstance(optimizer, paddle.distributed.fleet.Fleet):
                        lr_sche = optimizer.user_defined_optimizer._learning_rate
                    else:
                        lr_sche = optimizer._learning_rate
                    if isinstance(lr_sche, paddle.optimizer.lr.LRScheduler):
                        lr_sche.step()

                    if self.cfg['save_edge']:
                        if src_edge is None:
                            # Edge constraint is off, so the source edge
                            # prediction has not been computed yet.
                            src_edge = paddle.argmax(
                                logits_list_src[2].detach().clone(), axis=1)
                            src_edge_acc = self._percent_equal(
                                src_edge, edges_src)
                        tgt_edge = paddle.argmax(
                            logits_list_tgt_aug[2].detach().clone(), axis=1)
                        src_feed_gt = paddle.argmax(
                            src_edge_logit.astype('float32'), axis=1)
                        tgt_feed_gt = paddle.argmax(
                            tgt_edge_logit.astype('float32'), axis=1)
                        logger.info('src_feed_gt_{}_{}_{}'.format(
                            src_feed_gt.shape, src_feed_gt.max(),
                            src_feed_gt.min()))
                        # Tensor reductions, as for the source side; the
                        # builtin max()/min() iterate over the first axis.
                        logger.info('tgt_feed_gt_{}_{}_{}'.format(
                            tgt_feed_gt.shape, tgt_feed_gt.max(),
                            tgt_feed_gt.min()))
                        save_edge(src_feed_gt,
                                  'src_feed_gt_{}'.format(cur_iter))
                        save_edge(tgt_feed_gt,
                                  'tgt_feed_gt_{}'.format(cur_iter))
                        save_edge(tgt_edge, 'tgt_pred_{}'.format(cur_iter))
                        save_edge(
                            src_edge,
                            'src_pred_{}_{}'.format(cur_iter, src_edge_acc))
                        save_edge(edges_src, 'src_gt_{}'.format(cur_iter))
                        save_edge(edges_tgt, 'tgt_gt_{}'.format(cur_iter))

                    self.model.clear_gradients()

                    batch_cost_averager.record(time.time() - batch_start,
                                               num_samples=batch_size)

                    cur_iter += 1
                    if cur_iter % log_iters == 0 and local_rank == 0:
                        # Agreement between the pseudo labels and the target
                        # labels of the batch (monitoring only).
                        label_tgt_acc = self._percent_equal(
                            labels_tgt_psu, labels_tgt)

                        remain_iters = iters - cur_iter
                        avg_train_batch_cost = batch_cost_averager.get_average(
                        )
                        avg_train_reader_cost = reader_cost_averager.get_average(
                        )
                        eta = calculate_eta(remain_iters, avg_train_batch_cost)
                        logger.info(
                            "[TRAIN] epoch: {}, iter: {}/{}, loss: {:.4f}, tgt_pix_acc: {:.4f}, lr: {:.6f}, batch_cost: {:.4f}, reader_cost: {:.5f}, ips: {:.4f} samples/sec | ETA {}"
                            .format(
                                (cur_iter - 1) // iters_per_epoch + 1,
                                cur_iter, iters, loss_value, label_tgt_acc,
                                lr, avg_train_batch_cost,
                                avg_train_reader_cost,
                                batch_cost_averager.get_ips_average(), eta))

                        if use_vdl:
                            log_writer.add_scalar('Train/loss', loss_value,
                                                  cur_iter)
                            # Record all losses if there are more than 2 losses.
                            if len(loss_dict) > 1:
                                for name, part in loss_dict.items():
                                    log_writer.add_scalar(
                                        'Train/loss_' + name, part, cur_iter)

                            log_writer.add_scalar('Train/lr', lr, cur_iter)
                            log_writer.add_scalar('Train/batch_cost',
                                                  avg_train_batch_cost,
                                                  cur_iter)
                            log_writer.add_scalar('Train/reader_cost',
                                                  avg_train_reader_cost,
                                                  cur_iter)
                            log_writer.add_scalar('Train/tgt_label_acc',
                                                  label_tgt_acc, cur_iter)

                        reader_cost_averager.reset()
                        batch_cost_averager.reset()

                    at_checkpoint = (cur_iter % save_interval == 0
                                     or cur_iter == iters)
                    # Which metrics were freshly computed this round; the
                    # logging below must not read unset or stale values.
                    evaluated_tgt = False
                    evaluated_src = False

                    if at_checkpoint and (val_dataset_tgt is not None):
                        num_workers = 4 if num_workers > 0 else 0  # adjust num_worker=4

                        if test_config is None:
                            test_config = {}
                        # Evaluate with the EMA weights swapped into the
                        # model; restore() puts the live weights back.
                        self.ema.apply_shadow()
                        self.ema.model.eval()

                        PA_tgt, _, MIoU_tgt, _ = val.evaluate(
                            self.model,
                            val_dataset_tgt,
                            num_workers=num_workers,
                            **test_config)
                        evaluated_tgt = True

                        if (cur_iter % (save_interval * 30)) == 0 \
                            and self.cfg['eval_src']:  # add evaluate on src
                            PA_src, _, MIoU_src, _ = val.evaluate(
                                self.model,
                                val_dataset_src,
                                num_workers=num_workers,
                                **test_config)
                            evaluated_src = True
                            logger.info(
                                '[EVAL] The source mIoU is ({:.4f}) at iter {}.'
                                .format(MIoU_src, cur_iter))

                        self.ema.restore()
                        self.model.train()

                    if at_checkpoint and local_rank == 0:
                        current_save_dir = os.path.join(
                            save_dir, "iter_{}".format(cur_iter))
                        os.makedirs(current_save_dir, exist_ok=True)
                        paddle.save(
                            self.model.state_dict(),
                            os.path.join(current_save_dir, 'model.pdparams'))
                        paddle.save(
                            self.ema.shadow,
                            os.path.join(current_save_dir,
                                         'model_ema.pdparams'))
                        paddle.save(
                            optimizer.state_dict(),
                            os.path.join(current_save_dir, 'model.pdopt'))
                        save_models.append(current_save_dir)
                        # keep_checkpoint_max <= 0 disables the rotation.
                        if len(save_models) > keep_checkpoint_max > 0:
                            model_to_remove = save_models.popleft()
                            shutil.rmtree(model_to_remove)

                        if evaluated_tgt:
                            if MIoU_tgt > best_mean_iou:
                                best_mean_iou = MIoU_tgt
                                best_model_iter = cur_iter
                                best_model_dir = os.path.join(
                                    save_dir, "best_model")
                                paddle.save(
                                    self.model.state_dict(),
                                    os.path.join(best_model_dir,
                                                 'model.pdparams'))

                            logger.info(
                                '[EVAL] The model with the best validation mIoU ({:.4f}) was saved at iter {}.'
                                .format(best_mean_iou, best_model_iter))

                        if use_vdl:
                            if evaluated_tgt:
                                log_writer.add_scalar('Evaluate/mIoU',
                                                      MIoU_tgt, cur_iter)
                                log_writer.add_scalar('Evaluate/PA', PA_tgt,
                                                      cur_iter)

                            if evaluated_src:
                                log_writer.add_scalar('Evaluate/mIoU_src',
                                                      MIoU_src, cur_iter)
                                log_writer.add_scalar('Evaluate/PA_src',
                                                      PA_src, cur_iter)

                    batch_start = time.time()

                # Stop mid-epoch once the budget is spent; without this the
                # loop kept training until the loaders were exhausted.
                if cur_iter >= iters:
                    break

            self.ema.update_buffer()

        # # Calculate flops.
        # Skipped when no batch was processed, since the input shape is
        # taken from the last source batch.
        if local_rank == 0 and images_src is not None:

            def count_syncbn(m, x, y):
                x = x[0]
                nelements = x.numel()
                m.total_ops += int(2 * nelements)

            _, c, h, w = images_src.shape
            paddle.flops(
                self.model, [1, c, h, w],
                custom_ops={paddle.nn.SyncBatchNorm: count_syncbn})

        # Sleep for half a second to let dataloader release resources.
        time.sleep(0.5)
        if use_vdl:
            log_writer.close()
Exemple #4
0
# Adam optimizer over the feature discriminator's parameters; the learning
# rate and both beta coefficients are taken from `args`.
optimizer_disc = torch.optim.Adam(feature_discriminator.parameters(),
                                  lr=args.lr,
                                  betas=(args.beta1, args.beta2))

# datasets.
# Training and evaluation iterators, both built from the same `args`.
iterator_train = GenerateIterator(args)
iterator_val = GenerateIterator_eval(args)

# loss params.
# NOTE(review): presumably per-term loss weights (the loop below accumulates
# domain, source-class, source-VAT, target-cent and target-VAT losses) -
# confirm which weight scales which term where they are used.
dw = 1e-2
cw = 1
sw = 1
tw = 1e-2
bw = 1e-2
''' Exponential moving average (simulating teacher model) '''
# EMA tracker built with 0.998 (presumably the decay rate - TODO confirm
# against the EMA class) and registered on the classifier.
ema = EMA(0.998)
ema.register(classifier)

# training..
for epoch in range(1, args.num_epoch):
    iterator_train.dataset.shuffledata()
    pbar = tqdm(iterator_train,
                disable=False,
                bar_format="{percentage:.0f}%,{elapsed},{remaining},{desc}")

    loss_main_sum, n_total = 0, 0
    loss_domain_sum, loss_src_class_sum, \
    loss_src_vat_sum, loss_trg_cent_sum, loss_trg_vat_sum = 0, 0, 0, 0, 0
    loss_disc_sum = 0

    for images_source, labels_source, images_target, labels_target in pbar: