Example #1
def train(cfg, local_rank, distributed):
    model = build_detection_model(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    if cfg.MODEL.USE_SYNCBN:
        assert is_pytorch_1_1_0_or_later(), \
            "SyncBatchNorm is only available in pytorch >= 1.1.0"
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank],
            output_device=local_rank,
            # this should be removed if we update BatchNorm stats
            broadcast_buffers=False,
            # find_unused_parameters=True,
        )

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = get_rank() == 0
    checkpointer = DetectronCheckpointer(cfg, model, optimizer, scheduler,
                                         output_dir, save_to_disk)
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
        is_distributed=distributed,
        start_iter=arguments["iteration"],
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD

    vis_period = cfg.VISUALIZE.PERIOD
    if 0 < vis_period < cfg.SOLVER.MAX_ITER:
        visualizer = SummaryWriterX(
            cfg.VISUALIZE.DIR + '/' + cfg.VISUALIZE.ENV, cfg.VISUALIZE.ENV,
            vis_period, 20, get_category(cfg.DATASETS.TRAIN[0]))
    else:
        visualizer = None

    meters = MetricLogger(delimiter="  ",
                          save_dir=os.path.join(output_dir, 'meters.json'))
    meters.load(is_main_process=get_rank() == 0)

    do_train(model, data_loader, optimizer, scheduler, checkpointer, device,
             checkpoint_period, arguments, meters, visualizer)

    return model
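
A train() entry point like this is normally driven by a small distributed launcher. The sketch below is only an assumption about what such a launcher looks like for this maskrcnn-benchmark-style code (the maskrcnn_benchmark.config import, argument names, and env:// initialization follow the usual pattern and are not shown on this page):

# Hypothetical launcher sketch; not the repository's actual tools/train_net.py.
import argparse
import os

import torch

from maskrcnn_benchmark.config import cfg  # assumed yacs-style project config


def main():
    parser = argparse.ArgumentParser(description="PyTorch detection training")
    parser.add_argument("--config-file", default="", metavar="FILE")
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

    # torch.distributed.launch sets WORLD_SIZE for multi-GPU runs
    num_gpus = int(os.environ.get("WORLD_SIZE", 1))
    distributed = num_gpus > 1
    if distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    cfg.merge_from_file(args.config_file)
    cfg.freeze()

    train(cfg, args.local_rank, distributed)


if __name__ == "__main__":
    main()

The key point is that local_rank and the distributed flag come from the launch environment before train(cfg, local_rank, distributed) is called.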
Example #2
    def test_update(self):
        meter = MetricLogger()
        for i in range(10):
            meter.update(metric=float(i))

        m = meter.meters["metric"]
        self.assertEqual(m.count, 10)
        self.assertEqual(m.total, 45)
        self.assertEqual(m.median, 4)
        self.assertEqual(m.avg, 4.5)
Example #3
    def test_no_attr(self):
        meter = MetricLogger()
        _ = meter.meters
        _ = meter.delimiter

        def broken():
            _ = meter.not_existent

        self.assertRaises(AttributeError, broken)
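
Examples #2 and #3 together pin down the minimal MetricLogger interface: keyword updates that feed per-metric running statistics (count, total, median, avg), attribute access that falls through to the stored meters, and an AttributeError for unknown names. The following is a minimal sketch reconstructed only from what these tests exercise; the real loggers used elsewhere on this page add more (windowed smoothing, global_avg, bind(), summary_str, JSON saving):

from collections import defaultdict, deque

import torch


class SmoothedValue:
    """Tracks a series of values and exposes simple running statistics."""

    def __init__(self, window_size=20):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.deque.append(value)
        self.count += 1
        self.total += value

    @property
    def median(self):
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        return torch.tensor(list(self.deque)).mean().item()


class MetricLogger:
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            self.meters[name].update(float(value))

    def __getattr__(self, attr):
        # only reached when normal attribute lookup fails; fall through to meters
        if attr in self.meters:
            return self.meters[attr]
        raise AttributeError(attr)

    def __str__(self):
        return self.delimiter.join(
            "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.avg)
            for name, meter in self.meters.items())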
Example #4
def validate(model, loss_fn, metric, dataloader, log_period=-1):
    logger = logging.getLogger('shaper.validate')
    meters = MetricLogger(delimiter='  ')
    metric.reset()
    meters.bind(metric)
    model.eval()
    loss_fn.eval()

    end = time.time()
    with torch.no_grad():
        for iteration, data_batch in enumerate(dataloader):
            data_time = time.time() - end

            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }

            preds = model(data_batch)

            loss_dict = loss_fn(preds, data_batch)
            total_loss = sum(loss_dict.values())

            meters.update(node_acc=preds['node_acc'],
                          node_pos_acc=preds['node_pos_acc'],
                          node_neg_acc=preds['node_neg_acc'],
                          center_valid_ratio=preds['center_valid_ratio'])
            meters.update(loss=total_loss, **loss_dict)
            metric.update_dict(preds, data_batch)

            batch_time = time.time() - end
            meters.update(time=batch_time, data=data_time)
            end = time.time()

            if log_period > 0 and iteration % log_period == 0:
                logger.info(
                    meters.delimiter.join([
                        'iter: {iter:4d}',
                        '{meters}',
                    ]).format(
                        iter=iteration,
                        meters=str(meters),
                    ))
    return meters
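
A typical caller then reads aggregated values back from the returned logger. The snippet below is a hypothetical usage example; the exact statistic names (avg, global_avg) vary slightly between the implementations shown on this page:

# Hypothetical caller; model, loss_fn, metric and val_loader come from the
# project's own builders.
eval_meters = validate(model, loss_fn, metric, val_loader, log_period=50)

# Attribute access falls through to the underlying meters (cf. Example #3).
print("mean batch time: {:.3f}s".format(eval_meters.time.avg))
print("validation loss: {:.4f}".format(eval_meters.loss.avg))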
Example #5
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SDCA.trainer")

    # create network
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor = build_feature_extractor(cfg)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(classifier)

    # batch size: half for source and half for target
    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        batch_size = int(cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg1
        )
        pg2 = torch.distributed.new_group(range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True, process_group=pg2
        )
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    # init optimizer
    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(), lr=cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(), lr=cfg.SOLVER.BASE_LR * 10, momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_weights = checkpoint['feature_extractor'] if distributed else strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_weights)
        classifier_weights = checkpoint['classifier'] if distributed else strip_prefix_if_present(
            checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    # init data loader
    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)
    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None
    src_train_loader = torch.utils.data.DataLoader(src_train_data, batch_size=batch_size,
                                                   shuffle=(src_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=src_train_sampler, drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(tgt_train_data, batch_size=batch_size,
                                                   shuffle=(tgt_train_sampler is None), num_workers=4,
                                                   pin_memory=True, sampler=tgt_train_sampler, drop_last=True)

    # init loss
    ce_criterion = nn.CrossEntropyLoss(ignore_index=255)
    pcl_criterion = PixelContrastiveLoss(cfg)

    # load semantic distributions
    logger.info(">>>>>>>>>>>>>>>> Load semantic distributions >>>>>>>>>>>>>>>>")
    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    if cfg.SOLVER.MULTI_LEVEL:
        out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES, cfg=cfg)

    iteration = 0
    start_training_time = time.time()
    end = time.time()
    save_to_disk = local_rank == 0
    max_iters = cfg.SOLVER.MAX_ITER
    meters = MetricLogger(delimiter="  ")

    logger.info(">>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>")
    feature_extractor.train()
    classifier.train()

    for i, ((src_input, src_label, src_name), (tgt_input, _, _)) in enumerate(zip(src_train_loader, tgt_train_loader)):

        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD, cfg.SOLVER.BASE_LR, iteration, max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        src_feat = feature_extractor(src_input)
        src_out = classifier(src_feat)
        tgt_feat = feature_extractor(tgt_input)
        tgt_out = classifier(tgt_feat)
        tgt_out_softmax = F.softmax(tgt_out, dim=1)

        # supervision loss
        src_pred = F.interpolate(src_out, size=src_size, mode='bilinear', align_corners=True)
        if cfg.SOLVER.LAMBDA_LOV > 0:
            src_pred_softmax = F.softmax(src_pred, dim=1)
            loss_lov = lovasz_softmax(src_pred_softmax, src_label, ignore=255)
            loss_sup = ce_criterion(src_pred, src_label) + cfg.SOLVER.LAMBDA_LOV * loss_lov
            meters.update(loss_lov=loss_lov.item())
        else:
            loss_sup = ce_criterion(src_pred, src_label)
        meters.update(loss_sup=loss_sup.item())

        # source mask: downsample the ground-truth label
        B, A, Hs, Ws = src_feat.size()
        src_mask = F.interpolate(src_label.unsqueeze(0).float(), size=(Hs, Ws), mode='nearest').squeeze(0).long()
        src_mask = src_mask.contiguous().view(B * Hs * Ws, )
        assert not src_mask.requires_grad
        # target mask: constant threshold
        _, _, Ht, Wt = tgt_feat.size()
        tgt_out_maxvalue, tgt_mask = torch.max(tgt_out_softmax, dim=1)
        # mark low-confidence target predictions as ignore (255) for each class
        for cls_idx in range(cfg.MODEL.NUM_CLASSES):
            tgt_mask[(tgt_out_maxvalue < cfg.SOLVER.DELTA) * (tgt_mask == cls_idx)] = 255
        tgt_mask = tgt_mask.contiguous().view(B * Ht * Wt, )
        assert not tgt_mask.requires_grad

        src_feat = src_feat.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, A)
        tgt_feat = tgt_feat.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, A)
        # update feature-level statistics
        feat_estimator.update(features=src_feat.detach(), labels=src_mask)

        # contrastive loss on both domains
        loss_feat = pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                  CoVariance=feat_estimator.CoVariance.detach(),
                                  feat=src_feat,
                                  labels=src_mask) \
                    + pcl_criterion(Mean=feat_estimator.Mean.detach(),
                                    CoVariance=feat_estimator.CoVariance.detach(),
                                    feat=tgt_feat,
                                    labels=tgt_mask)
        meters.update(loss_feat=loss_feat.item())

        if cfg.SOLVER.MULTI_LEVEL:
            src_out = src_out.permute(0, 2, 3, 1).contiguous().view(B * Hs * Ws, cfg.MODEL.NUM_CLASSES)
            tgt_out = tgt_out.permute(0, 2, 3, 1).contiguous().view(B * Ht * Wt, cfg.MODEL.NUM_CLASSES)

            # update output-level statistics
            out_estimator.update(features=src_out.detach(), labels=src_mask)

            # the proposed contrastive loss on prediction map
            loss_out = pcl_criterion(Mean=out_estimator.Mean.detach(),
                                     CoVariance=out_estimator.CoVariance.detach(),
                                     feat=src_out,
                                     labels=src_mask) \
                       + pcl_criterion(Mean=out_estimator.Mean.detach(),
                                       CoVariance=out_estimator.CoVariance.detach(),
                                       feat=tgt_out,
                                       labels=tgt_mask)
            meters.update(loss_out=loss_out.item())

            loss = loss_sup \
                   + cfg.SOLVER.LAMBDA_FEAT * loss_feat \
                   + cfg.SOLVER.LAMBDA_OUT * loss_out
        else:
            loss = loss_sup + cfg.SOLVER.LAMBDA_FEAT * loss_feat

        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER - iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        iteration += 1
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join(
                    [
                        "eta: {eta}",
                        "iter: {iter}",
                        "{meters}",
                        "lr: {lr:.6f}",
                        "max mem: {memory:.02f} GB"
                    ]
                ).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 / 1024.0
                )
            )

        if (iteration == cfg.SOLVER.MAX_ITER or iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(cfg.OUTPUT_DIR, "model_iter{:06d}.pth".format(iteration))
            torch.save({'iteration': iteration,
                        'feature_extractor': feature_extractor.state_dict(),
                        'classifier': classifier.state_dict(),
                        'optimizer_fea': optimizer_fea.state_dict(),
                        'optimizer_cls': optimizer_cls.state_dict(),
                        }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info(
        "Total training time: {} ({:.4f} s / it)".format(
            total_time_str, total_training_time / cfg.SOLVER.STOP_ITER
        )
    )

    return feature_extractor, classifier
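
Example #5 delegates the per-iteration learning rate to an adjust_learning_rate helper that is not shown here. The sketch below is one plausible implementation matching the call signature, assuming the common 'poly' decay; the actual helper in the SDCA codebase may differ:

def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    """Return the learning rate for the current iteration (sketch only)."""
    if method == 'poly':
        # polynomial decay down to 0 at max_iters
        return base_lr * (1.0 - float(iteration) / max_iters) ** power
    if method == 'step':
        # halve the rate every quarter of training (illustrative choice)
        return base_lr * (0.5 ** (iteration // max(1, max_iters // 4)))
    return base_lr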
Example #6
def test(cfg, output_dir='', output_dir_merge='', output_dir_save=''):
    logger = logging.getLogger('shaper.test')

    # build model
    model, loss_fn, _, val_metric = build_model(cfg)
    model = nn.DataParallel(model).cuda()
    model_merge = nn.DataParallel(PointNetCls(in_channels=3,
                                              out_channels=128)).cuda()

    # build checkpointer
    checkpointer = Checkpointer(model, save_dir=output_dir, logger=logger)
    checkpointer_merge = Checkpointer(model_merge,
                                      save_dir=output_dir_merge,
                                      logger=logger)

    if cfg.TEST.WEIGHT:
        # load weight if specified
        weight_path = cfg.TEST.WEIGHT.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)
        checkpointer_merge.load(None, resume=True)
        #checkpointer_refine.load(None, resume=True)

    # build data loader
    test_dataloader = build_dataloader(cfg, mode='test')
    test_dataset = test_dataloader.dataset

    assert cfg.TEST.BATCH_SIZE == 1, '{} != 1'.format(cfg.TEST.BATCH_SIZE)
    save_fig_dir = osp.join(output_dir_save, 'test_fig')
    os.makedirs(save_fig_dir, exist_ok=True)
    save_fig_dir_size = osp.join(save_fig_dir, 'size')
    os.makedirs(save_fig_dir_size, exist_ok=True)
    save_fig_dir_gt = osp.join(save_fig_dir, 'gt')
    os.makedirs(save_fig_dir_gt, exist_ok=True)

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    model_merge.eval()
    loss_fn.eval()
    softmax = nn.Softmax(dim=0)  # the score vectors passed below are 1-D, so normalize over dim 0
    set_random_seed(cfg.RNG_SEED)

    NUM_POINT = 10000
    n_shape = len(test_dataloader)
    NUM_INS = 200
    out_mask = np.zeros((n_shape, NUM_INS, NUM_POINT), dtype=bool)
    out_valid = np.zeros((n_shape, NUM_INS), dtype=bool)
    out_conf = np.ones((n_shape, NUM_INS), dtype=np.float32)

    meters = MetricLogger(delimiter='  ')
    meters.bind(val_metric)
    tot_purity_error_list = list()
    tot_purity_error_small_list = list()
    tot_purity_error_large_list = list()
    tot_pred_acc = list()
    tot_pred_small_acc = list()
    tot_pred_large_acc = list()
    tot_mean_rela_size_list = list()
    tot_mean_policy_label0 = list()
    tot_mean_label_policy0 = list()
    tot_mean_policy_label0_large = list()
    tot_mean_policy_label0_small = list()
    tot_mean_label_policy0_large = list()
    tot_mean_label_policy0_small = list()
    with torch.no_grad():
        start_time = time.time()
        end = start_time
        for iteration, data_batch in enumerate(test_dataloader):
            print(iteration)

            data_time = time.time() - end
            iter_start_time = time.time()

            data_batch = {
                k: v.cuda(non_blocking=True)
                for k, v in data_batch.items()
            }

            preds = model(data_batch)
            loss_dict = loss_fn(preds, data_batch)
            meters.update(**loss_dict)
            val_metric.update_dict(preds, data_batch)

            #extraction box features
            batch_size, _, num_centroids, num_neighbours = data_batch[
                'neighbour_xyz'].shape
            num_points = data_batch['points'].shape[-1]

            #batch_size, num_centroid, num_neighbor
            _, p = torch.max(preds['ins_logit'], 1)
            box_index_expand = torch.zeros(
                (batch_size * num_centroids, num_points)).cuda()
            box_index_expand = box_index_expand.scatter_(
                dim=1,
                index=data_batch['neighbour_index'].reshape(
                    [-1, num_neighbours]),
                src=p.reshape([-1, num_neighbours]).float())
            #centroid_label = data_batch['centroid_label'].reshape(-1)

            minimum_box_pc_num = 16
            minimum_overlap_pc_num = 16  #1/16 * num_neighbour
            gtmin_mask = (torch.sum(box_index_expand, dim=-1) >
                          minimum_box_pc_num)

            #remove purity < 0.8
            box_label_expand = torch.zeros(
                (batch_size * num_centroids, 200)).cuda()
            purity_pred = torch.zeros([0]).type(torch.LongTensor).cuda()
            purity_pred_float = torch.zeros([0]).type(torch.FloatTensor).cuda()

            for i in range(batch_size):
                cur_xyz_pool, xyz_mean = mask_to_xyz(
                    data_batch['points'][i],
                    box_index_expand.view(batch_size, num_centroids,
                                          num_points)[i],
                    sample_num=512)
                cur_xyz_pool -= xyz_mean
                cur_xyz_pool /= (cur_xyz_pool + 1e-6).norm(dim=1).max(
                    dim=-1)[0].unsqueeze(-1).unsqueeze(-1)

                logits_purity = model_merge(cur_xyz_pool, 'purity')
                p = (logits_purity > 0.8).long().squeeze()
                purity_pred = torch.cat([purity_pred, p])
                purity_pred_float = torch.cat(
                    [purity_pred_float,
                     logits_purity.squeeze()])

            p_thresh = 0.8
            purity_pred = purity_pred_float > p_thresh
            #in case remove too much
            while (torch.sum(purity_pred) < 48):
                p_thresh = p_thresh - 0.01
                purity_pred = purity_pred_float > p_thresh
            valid_mask = gtmin_mask.long() * purity_pred.long()
            box_index_expand = torch.index_select(
                box_index_expand, dim=0, index=valid_mask.nonzero().squeeze())

            box_num = torch.sum(valid_mask.reshape(batch_size, num_centroids),
                                1)
            cumsum_box_num = torch.cumsum(box_num, dim=0)
            cumsum_box_num = torch.cat([
                torch.from_numpy(np.array(0)).cuda().unsqueeze(0),
                cumsum_box_num
            ],
                                       dim=0)

            with torch.no_grad():
                pc_all = data_batch['points']
                xyz_pool1 = torch.zeros([0, 3, 1024]).float().cuda()
                xyz_pool2 = torch.zeros([0, 3, 1024]).float().cuda()
                label_pool = torch.zeros([0]).float().cuda()
                for i in range(pc_all.shape[0]):
                    bs = 1
                    pc = pc_all[i].clone()
                    cur_mask_pool = box_index_expand[
                        cumsum_box_num[i]:cumsum_box_num[i + 1]].clone()
                    cover_ratio = torch.unique(
                        cur_mask_pool.nonzero()[:, 1]).shape[0] / num_points
                    #print(iteration, cover_ratio)
                    cur_xyz_pool, xyz_mean = mask_to_xyz(pc, cur_mask_pool)
                    subpart_pool = cur_xyz_pool.clone()
                    subpart_mask_pool = cur_mask_pool.clone()
                    init_pool_size = cur_xyz_pool.shape[0]
                    meters.update(cover_ratio=cover_ratio,
                                  init_pool_size=init_pool_size)
                    negative_num = 0
                    positive_num = 0

                    #remove I
                    inter_matrix = torch.matmul(cur_mask_pool,
                                                cur_mask_pool.transpose(0, 1))
                    inter_matrix_full = inter_matrix.clone(
                    ) > minimum_overlap_pc_num
                    inter_matrix[torch.eye(inter_matrix.shape[0]).byte()] = 0
                    pair_idx = (inter_matrix.triu() >
                                minimum_overlap_pc_num).nonzero()
                    zero_pair = torch.ones([0, 2]).long()
                    purity_matrix = torch.zeros(inter_matrix.shape).cuda()
                    policy_matrix = torch.zeros(inter_matrix.shape).cuda()
                    bsp = 64
                    idx = torch.arange(pair_idx.shape[0]).cuda()
                    #calculate initial policy score matrix
                    purity_pool = torch.zeros([0]).float().cuda()
                    policy_pool = torch.zeros([0]).float().cuda()
                    for k in range(int(np.ceil(idx.shape[0] / bsp))):
                        sub_part_idx = torch.index_select(
                            pair_idx, dim=0, index=idx[k * bsp:(k + 1) * bsp])
                        part_xyz1 = torch.index_select(cur_xyz_pool,
                                                       dim=0,
                                                       index=sub_part_idx[:,
                                                                          0])
                        part_xyz2 = torch.index_select(cur_xyz_pool,
                                                       dim=0,
                                                       index=sub_part_idx[:,
                                                                          1])
                        part_xyz = torch.cat([part_xyz1, part_xyz2], -1)
                        part_xyz -= torch.mean(part_xyz, -1).unsqueeze(-1)
                        part_norm = part_xyz.norm(dim=1).max(
                            dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        part_xyz /= part_norm
                        logits_purity = model_merge(part_xyz,
                                                    'purity').squeeze()
                        if len(logits_purity.shape) == 0:
                            logits_purity = logits_purity.unsqueeze(0)
                        purity_pool = torch.cat([purity_pool, logits_purity],
                                                dim=0)

                        part_xyz11 = part_xyz1 - torch.mean(part_xyz1,
                                                            -1).unsqueeze(-1)
                        part_xyz22 = part_xyz2 - torch.mean(part_xyz2,
                                                            -1).unsqueeze(-1)
                        part_xyz11 /= part_norm
                        part_xyz22 /= part_norm
                        logits11 = model_merge(part_xyz11, 'policy')
                        logits22 = model_merge(part_xyz22, 'policy')
                        policy_scores = model_merge(
                            torch.cat([logits11, logits22], dim=-1),
                            'policy_head').squeeze()
                        if len(policy_scores.shape) == 0:
                            policy_scores = policy_scores.unsqueeze(0)
                        policy_pool = torch.cat([policy_pool, policy_scores],
                                                dim=0)

                    purity_matrix[pair_idx[:, 0], pair_idx[:, 1]] = purity_pool
                    policy_matrix[pair_idx[:, 0], pair_idx[:, 1]] = policy_pool
                    score_matrix = torch.zeros(purity_matrix.shape).cuda()
                    score_matrix[pair_idx[:, 0], pair_idx[:, 1]] = softmax(
                        purity_pool * policy_pool)
                    meters.update(initial_pair_num=pair_idx.shape[0])
                    iteration_num = 0
                    remote_flag = False

                    #info
                    policy_list = []
                    purity_list = []
                    gt_purity_list = []
                    gt_label_list = []
                    pred_label_list = []
                    size_list = []
                    relative_size_list = []

                    while (pair_idx.shape[0] > 0) or (remote_flag == False):
                        if pair_idx.shape[0] == 0:
                            remote_flag = True
                            inter_matrix = 20 * torch.ones([
                                cur_mask_pool.shape[0], cur_mask_pool.shape[0]
                            ]).cuda()
                            inter_matrix[zero_pair[:, 0], zero_pair[:, 1]] = 0
                            inter_matrix[torch.eye(
                                inter_matrix.shape[0]).byte()] = 0
                            pair_idx = (inter_matrix.triu() >
                                        minimum_overlap_pc_num).nonzero()
                            if pair_idx.shape[0] == 0:
                                break
                            purity_matrix = torch.zeros(
                                inter_matrix.shape).cuda()
                            policy_matrix = torch.zeros(
                                inter_matrix.shape).cuda()
                            bsp = 64
                            idx = torch.arange(pair_idx.shape[0]).cuda()
                            purity_pool = torch.zeros([0]).float().cuda()
                            policy_pool = torch.zeros([0]).float().cuda()
                            for k in range(int(np.ceil(idx.shape[0] / bsp))):
                                sub_part_idx = torch.index_select(
                                    pair_idx,
                                    dim=0,
                                    index=idx[k * bsp:(k + 1) * bsp])
                                part_xyz1 = torch.index_select(
                                    cur_xyz_pool,
                                    dim=0,
                                    index=sub_part_idx[:, 0])
                                part_xyz2 = torch.index_select(
                                    cur_xyz_pool,
                                    dim=0,
                                    index=sub_part_idx[:, 1])
                                part_xyz = torch.cat([part_xyz1, part_xyz2],
                                                     -1)
                                part_xyz -= torch.mean(part_xyz,
                                                       -1).unsqueeze(-1)
                                part_norm = part_xyz.norm(dim=1).max(
                                    dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                                part_xyz /= part_norm
                                logits_purity = model_merge(
                                    part_xyz, 'purity').squeeze()
                                if len(logits_purity.shape) == 0:
                                    logits_purity = logits_purity.unsqueeze(0)
                                purity_pool = torch.cat(
                                    [purity_pool, logits_purity], dim=0)

                                part_xyz11 = part_xyz1 - torch.mean(
                                    part_xyz1, -1).unsqueeze(-1)
                                part_xyz22 = part_xyz2 - torch.mean(
                                    part_xyz2, -1).unsqueeze(-1)
                                part_xyz11 /= part_norm
                                part_xyz22 /= part_norm
                                logits11 = model_merge(part_xyz11, 'policy')
                                logits22 = model_merge(part_xyz22, 'policy')
                                policy_scores = model_merge(
                                    torch.cat([logits11, logits22], dim=-1),
                                    'policy_head').squeeze()
                                if len(policy_scores.shape) == 0:
                                    policy_scores = policy_scores.unsqueeze(0)
                                policy_pool = torch.cat(
                                    [policy_pool, policy_scores], dim=0)
                            purity_matrix[pair_idx[:, 0],
                                          pair_idx[:, 1]] = purity_pool
                            policy_matrix[pair_idx[:, 0],
                                          pair_idx[:, 1]] = policy_pool
                            score_matrix = torch.zeros(
                                purity_matrix.shape).cuda()
                            score_matrix[pair_idx[:, 0],
                                         pair_idx[:, 1]] = softmax(
                                             purity_pool * policy_pool)
                        iteration_num += 1

                        #everytime select the pair with highest score
                        score_arr = score_matrix[pair_idx[:, 0], pair_idx[:,
                                                                          1]]
                        highest_score, rank_idx = torch.topk(score_arr,
                                                             1,
                                                             largest=True,
                                                             sorted=False)
                        perm_idx = rank_idx
                        assert highest_score == score_matrix[pair_idx[rank_idx,
                                                                      0],
                                                             pair_idx[rank_idx,
                                                                      1]]

                        sub_part_idx = torch.index_select(pair_idx,
                                                          dim=0,
                                                          index=perm_idx[:bs])
                        purity_score = purity_matrix[sub_part_idx[:, 0],
                                                     sub_part_idx[:, 1]]
                        policy_score = policy_matrix[sub_part_idx[:, 0],
                                                     sub_part_idx[:, 1]]

                        #info
                        policy_list.append(policy_score.cpu().data.numpy()[0])
                        purity_list.append(purity_score.cpu().data.numpy()[0])

                        part_xyz1 = torch.index_select(cur_xyz_pool,
                                                       dim=0,
                                                       index=sub_part_idx[:,
                                                                          0])
                        part_xyz2 = torch.index_select(cur_xyz_pool,
                                                       dim=0,
                                                       index=sub_part_idx[:,
                                                                          1])
                        part_xyz = torch.cat([part_xyz1, part_xyz2], -1)
                        part_xyz -= torch.mean(part_xyz, -1).unsqueeze(-1)
                        part_xyz1 -= torch.mean(part_xyz1, -1).unsqueeze(-1)
                        part_xyz2 -= torch.mean(part_xyz2, -1).unsqueeze(-1)
                        part_xyz1 /= part_xyz1.norm(dim=1).max(
                            dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        part_xyz2 /= part_xyz2.norm(dim=1).max(
                            dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        part_xyz /= part_xyz.norm(dim=1).max(
                            dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        part_mask11 = torch.index_select(cur_mask_pool,
                                                         dim=0,
                                                         index=sub_part_idx[:,
                                                                            0])
                        part_mask22 = torch.index_select(cur_mask_pool,
                                                         dim=0,
                                                         index=sub_part_idx[:,
                                                                            1])
                        context_idx1 = torch.index_select(
                            inter_matrix_full, dim=0, index=sub_part_idx[:, 0])
                        context_idx2 = torch.index_select(
                            inter_matrix_full, dim=0, index=sub_part_idx[:, 1])
                        context_mask1 = (torch.matmul(
                            context_idx1.float(), cur_mask_pool) > 0).float()
                        context_mask2 = (torch.matmul(
                            context_idx2.float(), cur_mask_pool) > 0).float()
                        context_mask = ((context_mask1 + context_mask2) >
                                        0).float()
                        context_xyz, xyz_mean = mask_to_xyz(pc,
                                                            context_mask,
                                                            sample_num=2048)
                        context_xyz = context_xyz - xyz_mean
                        context_xyz /= context_xyz.norm(dim=1).max(
                            dim=-1)[0].unsqueeze(-1).unsqueeze(-1)

                        if cfg.DATASET.PartNetInsSeg.TEST.shape not in [
                                'Chair', 'Lamp', 'StorageFurniture'
                        ]:
                            logits1 = model_merge(part_xyz1, 'backbone')
                            logits2 = model_merge(part_xyz2, 'backbone')
                            merge_logits = model_merge(
                                torch.cat([
                                    part_xyz,
                                    torch.cat([
                                        logits1.unsqueeze(-1).expand(
                                            -1, -1, part_xyz1.shape[-1]),
                                        logits2.unsqueeze(-1).expand(
                                            -1, -1, part_xyz2.shape[-1])
                                    ],
                                              dim=-1)
                                ],
                                          dim=1), 'head')
                        else:
                            if (cur_xyz_pool.shape[0] >= 32):
                                logits1 = model_merge(part_xyz1, 'backbone')
                                logits2 = model_merge(part_xyz2, 'backbone')
                                merge_logits = model_merge(
                                    torch.cat([
                                        part_xyz,
                                        torch.cat([
                                            logits1.unsqueeze(-1).expand(
                                                -1, -1, part_xyz1.shape[-1]),
                                            logits2.unsqueeze(-1).expand(
                                                -1, -1, part_xyz2.shape[-1])
                                        ],
                                                  dim=-1)
                                    ],
                                              dim=1), 'head')
                            else:
                                logits1 = model_merge(part_xyz1, 'backbone')
                                logits2 = model_merge(part_xyz2, 'backbone')
                                context_logits = model_merge(
                                    context_xyz, 'backbone2')
                                merge_logits = model_merge(
                                    torch.cat([
                                        part_xyz,
                                        torch.cat([
                                            logits1.unsqueeze(-1).expand(
                                                -1, -1, part_xyz1.shape[-1]),
                                            logits2.unsqueeze(-1).expand(
                                                -1, -1, part_xyz2.shape[-1])
                                        ],
                                                  dim=-1),
                                        torch.cat([
                                            context_logits.unsqueeze(-1).
                                            expand(-1, -1, part_xyz.shape[-1])
                                        ],
                                                  dim=-1)
                                    ],
                                              dim=1), 'head2')

                        _, p = torch.max(merge_logits, 1)
                        if not remote_flag:
                            siamese_label = p * (
                                (purity_score > p_thresh).long())
                        else:
                            siamese_label = p
                        # the purity gate is re-applied unconditionally here, so it
                        # also filters predictions made in the remote_flag branch
                        siamese_label = p * ((purity_score > p_thresh).long())
                        negative_num += torch.sum(siamese_label == 0)
                        positive_num += torch.sum(siamese_label == 1)
                        pred_label_list.append(
                            siamese_label.cpu().data.numpy())

                        #info
                        new_part_mask = 1 - (1 - part_mask11) * (1 -
                                                                 part_mask22)
                        size_list.append(
                            torch.sum(new_part_mask).cpu().data.numpy())
                        size1 = torch.sum(part_mask11).cpu().data.numpy()
                        size2 = torch.sum(part_mask22).cpu().data.numpy()
                        relative_size_list.append(size1 / size2 +
                                                  size2 / size1)

                        #update info
                        merge_idx1 = torch.index_select(
                            sub_part_idx[:, 0],
                            dim=0,
                            index=siamese_label.nonzero().squeeze())
                        merge_idx2 = torch.index_select(
                            sub_part_idx[:, 1],
                            dim=0,
                            index=siamese_label.nonzero().squeeze())
                        merge_idx = torch.unique(
                            torch.cat([merge_idx1, merge_idx2], dim=0))
                        nonmerge_idx1 = torch.index_select(
                            sub_part_idx[:, 0],
                            dim=0,
                            index=(1 - siamese_label).nonzero().squeeze())
                        nonmerge_idx2 = torch.index_select(
                            sub_part_idx[:, 1],
                            dim=0,
                            index=(1 - siamese_label).nonzero().squeeze())
                        part_mask1 = torch.index_select(cur_mask_pool,
                                                        dim=0,
                                                        index=merge_idx1)
                        part_mask2 = torch.index_select(cur_mask_pool,
                                                        dim=0,
                                                        index=merge_idx2)
                        new_part_mask = 1 - (1 - part_mask1) * (1 - part_mask2)

                        equal_matrix = torch.matmul(
                            new_part_mask,
                            1 - new_part_mask.transpose(0, 1)) + torch.matmul(
                                1 - new_part_mask, new_part_mask.transpose(
                                    0, 1))
                        equal_matrix[torch.eye(
                            equal_matrix.shape[0]).byte()] = 1
                        fid = (equal_matrix == 0).nonzero()
                        if fid.shape[0] > 0:
                            flag = torch.ones(merge_idx1.shape[0])
                            for k in range(flag.shape[0]):
                                if flag[k] != 0:
                                    flag[fid[:, 1][fid[:, 0] == k]] = 0
                            new_part_mask = torch.index_select(
                                new_part_mask,
                                dim=0,
                                index=flag.nonzero().squeeze().cuda())

                        new_part_xyz, xyz_mean = mask_to_xyz(pc, new_part_mask)

                        #update purity and score, policy score matrix
                        if new_part_mask.shape[0] > 0:
                            overlap_idx = (
                                torch.matmul(cur_mask_pool,
                                             new_part_mask.transpose(0, 1)) >
                                minimum_overlap_pc_num).nonzero().squeeze()
                            if overlap_idx.shape[0] > 0:
                                if len(overlap_idx.shape) == 1:
                                    overlap_idx = overlap_idx.unsqueeze(0)
                                part_xyz1 = torch.index_select(
                                    cur_xyz_pool,
                                    dim=0,
                                    index=overlap_idx[:, 0])
                                part_xyz2 = tile(new_part_xyz, 0,
                                                 overlap_idx.shape[0])
                                part_xyz = torch.cat([part_xyz1, part_xyz2],
                                                     -1)
                                part_xyz -= torch.mean(part_xyz,
                                                       -1).unsqueeze(-1)
                                part_norm = part_xyz.norm(dim=1).max(
                                    dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                                part_xyz /= part_norm
                                overlap_purity_scores = model_merge(
                                    part_xyz, 'purity').squeeze()

                                part_xyz11 = part_xyz1 - torch.mean(
                                    part_xyz1, -1).unsqueeze(-1)
                                part_xyz22 = part_xyz2 - torch.mean(
                                    part_xyz2, -1).unsqueeze(-1)
                                part_xyz11 /= part_norm
                                part_xyz22 /= part_norm
                                logits11 = model_merge(part_xyz11, 'policy')
                                logits22 = model_merge(part_xyz22, 'policy')
                                overlap_policy_scores = model_merge(
                                    torch.cat([logits11, logits22], dim=-1),
                                    'policy_head').squeeze()

                                tmp_purity_arr = torch.zeros(
                                    [purity_matrix.shape[0]]).cuda()
                                tmp_policy_arr = torch.zeros(
                                    [policy_matrix.shape[0]]).cuda()
                                tmp_purity_arr[
                                    overlap_idx[:, 0]] = overlap_purity_scores
                                tmp_policy_arr[
                                    overlap_idx[:, 0]] = overlap_policy_scores
                                purity_matrix = torch.cat([
                                    purity_matrix,
                                    tmp_purity_arr.unsqueeze(1)
                                ],
                                                          dim=1)
                                policy_matrix = torch.cat([
                                    policy_matrix,
                                    tmp_policy_arr.unsqueeze(1)
                                ],
                                                          dim=1)
                                purity_matrix = torch.cat([
                                    purity_matrix,
                                    torch.zeros(purity_matrix.shape[1]).cuda().
                                    unsqueeze(0)
                                ])
                                policy_matrix = torch.cat([
                                    policy_matrix,
                                    torch.zeros(policy_matrix.shape[1]).cuda().
                                    unsqueeze(0)
                                ])
                            else:
                                purity_matrix = torch.cat([
                                    purity_matrix,
                                    torch.zeros(purity_matrix.shape[0]).cuda().
                                    unsqueeze(1)
                                ],
                                                          dim=1)
                                policy_matrix = torch.cat([
                                    policy_matrix,
                                    torch.zeros(policy_matrix.shape[0]).cuda().
                                    unsqueeze(1)
                                ],
                                                          dim=1)
                                purity_matrix = torch.cat([
                                    purity_matrix,
                                    torch.zeros(purity_matrix.shape[1]).cuda().
                                    unsqueeze(0)
                                ])
                                policy_matrix = torch.cat([
                                    policy_matrix,
                                    torch.zeros(policy_matrix.shape[1]).cuda().
                                    unsqueeze(0)
                                ])

                        cur_mask_pool = torch.cat(
                            [cur_mask_pool, new_part_mask], dim=0)
                        subpart_mask_pool = torch.cat(
                            [subpart_mask_pool, new_part_mask], dim=0)
                        cur_xyz_pool = torch.cat([cur_xyz_pool, new_part_xyz],
                                                 dim=0)
                        subpart_pool = torch.cat([subpart_pool, new_part_xyz],
                                                 dim=0)
                        cur_pool_size = cur_mask_pool.shape[0]
                        new_mask = torch.ones([cur_pool_size])
                        new_mask[merge_idx] = 0
                        new_idx = new_mask.nonzero().squeeze().cuda()
                        cur_xyz_pool = torch.index_select(cur_xyz_pool,
                                                          dim=0,
                                                          index=new_idx)
                        cur_mask_pool = torch.index_select(cur_mask_pool,
                                                           dim=0,
                                                           index=new_idx)
                        inter_matrix = torch.matmul(
                            cur_mask_pool, cur_mask_pool.transpose(0, 1))
                        inter_matrix_full = inter_matrix.clone(
                        ) > minimum_overlap_pc_num
                        if remote_flag:
                            inter_matrix = 20 * torch.ones([
                                cur_mask_pool.shape[0], cur_mask_pool.shape[0]
                            ]).cuda()
                        #update zero_matrix
                        zero_matrix = torch.zeros(
                            [cur_pool_size, cur_pool_size])
                        zero_matrix[zero_pair[:, 0], zero_pair[:, 1]] = 1
                        zero_matrix[nonmerge_idx1, nonmerge_idx2] = 1
                        zero_matrix[nonmerge_idx2, nonmerge_idx1] = 1
                        zero_matrix = torch.index_select(zero_matrix,
                                                         dim=0,
                                                         index=new_idx.cpu())
                        zero_matrix = torch.index_select(zero_matrix,
                                                         dim=1,
                                                         index=new_idx.cpu())
                        zero_pair = zero_matrix.nonzero()
                        inter_matrix[zero_pair[:, 0], zero_pair[:, 1]] = 0
                        inter_matrix[torch.eye(
                            inter_matrix.shape[0]).byte()] = 0
                        pair_idx = (inter_matrix.triu() >
                                    minimum_overlap_pc_num).nonzero()

                        purity_matrix = torch.index_select(purity_matrix,
                                                           dim=0,
                                                           index=new_idx)
                        purity_matrix = torch.index_select(purity_matrix,
                                                           dim=1,
                                                           index=new_idx)
                        policy_matrix = torch.index_select(policy_matrix,
                                                           dim=0,
                                                           index=new_idx)
                        policy_matrix = torch.index_select(policy_matrix,
                                                           dim=1,
                                                           index=new_idx)
                        score_matrix = torch.zeros(purity_matrix.shape).cuda()
                        score_idx = pair_idx
                        score_matrix[score_idx[:, 0],
                                     score_idx[:, 1]] = softmax(
                                         purity_matrix[score_idx[:, 0],
                                                       score_idx[:, 1]] *
                                         policy_matrix[score_idx[:, 0],
                                                       score_idx[:, 1]])
                    final_pool_size = subpart_pool.shape[0]
                    meters.update(final_pool_size=final_pool_size,
                                  negative_num=negative_num,
                                  positive_num=positive_num)
                    meters.update(iteration_num=iteration_num)
                    meters.update(iteration_time=time.time() - iter_start_time)

            t1 = torch.matmul(cur_mask_pool, 1 - cur_mask_pool.transpose(0, 1))
            t1[torch.eye(t1.shape[0]).byte()] = 1
            t1_id = (t1 == 0).nonzero()
            final_idx = torch.ones(t1.shape[0])
            final_idx[t1_id[:, 0]] = 0
            cur_mask_pool = torch.index_select(
                cur_mask_pool,
                dim=0,
                index=final_idx.nonzero().squeeze().cuda())

            pred_ins_label = torch.zeros(num_points).cuda()
            for k in range(cur_mask_pool.shape[0]):
                pred_ins_label[cur_mask_pool[k].byte()] = k + 1
            valid_idx = torch.sum(cur_mask_pool, 0) > 0
            if torch.sum(1 - valid_idx) != 0:
                valid_points = pc[:, valid_idx]
                invalid_points = pc[:, 1 - valid_idx]
                #perform knn to cover all points
                knn_index, _ = _F.knn_distance(invalid_points.unsqueeze(0),
                                               valid_points.unsqueeze(0), 5,
                                               False)
                invalid_pred, _ = pred_ins_label[valid_idx][
                    knn_index.squeeze()].mode()
                pred_ins_label[1 - valid_idx] = invalid_pred
            cur_mask_pool_new = torch.zeros([0, num_points]).cuda()
            for k in range(cur_mask_pool.shape[0]):
                if torch.sum(pred_ins_label == (k + 1)) != 0:
                    cur_mask_pool_new = torch.cat([
                        cur_mask_pool_new,
                        ((pred_ins_label == (k + 1)).float()).unsqueeze(0)
                    ],
                                                  dim=0)
            out_mask[iteration, :cur_mask_pool_new.shape[0]] = copy.deepcopy(
                cur_mask_pool_new.cpu().data.numpy().astype(bool))
            out_valid[iteration, :cur_mask_pool_new.shape[0]] = np.sum(
                cur_mask_pool_new.cpu().data.numpy()) > 10

    test_time = time.time() - start_time
    logger.info('Test {}  test time: {:.2f}s'.format(meters.summary_str,
                                                     test_time))
    for i in range(int(out_mask.shape[0] / 1024) + 1):
        save_h5(os.path.join(output_dir_save, 'test-%02d.h5' % (i)),
                out_mask[i * 1024:(i + 1) * 1024],
                out_valid[i * 1024:(i + 1) * 1024],
                out_conf[i * 1024:(i + 1) * 1024])
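
The final save_h5 helper used above is also not shown on this page. A minimal sketch, assuming h5py and illustrative dataset names ('mask', 'valid', 'conf'), could look like this:

import h5py


def save_h5(path, mask, valid, conf):
    # Dataset names are an assumption; the real helper in this codebase may differ.
    with h5py.File(path, 'w') as f:
        f.create_dataset('mask', data=mask, compression='gzip')
        f.create_dataset('valid', data=valid, compression='gzip')
        f.create_dataset('conf', data=conf, compression='gzip')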
Example #7
def train_one_epoch(model,
                    model_merge,
                    loss_fn,
                    metric,
                    dataloader,
                    cur_epoch,
                    optimizer,
                    optimizer_embed,
                    checkpointer_embed,
                    output_dir_merge,
                    max_grad_norm=0.0,
                    freezer=None,
                    log_period=-1):
    global xyz_pool1
    global xyz_pool2
    global context_xyz_pool1
    global context_xyz_pool2
    global context_context_xyz_pool
    global context_label_pool
    global context_purity_pool
    global label_pool
    global purity_purity_pool
    global purity_xyz_pool
    global policy_purity_pool
    global policy_reward_pool
    global policy_xyz_pool1
    global policy_xyz_pool2
    global old_iteration

    logger = logging.getLogger('shaper.train')
    meters = MetricLogger(delimiter='  ')
    metric.reset()
    meters.bind(metric)

    model.eval()
    loss_fn.eval()
    model_merge.eval()
    softmax = nn.Softmax()

    policy_total_bs = 8
    rnum = 1 if policy_total_bs-cur_epoch < 1 else policy_total_bs-cur_epoch
    end = time.time()

    buffer_txt = os.path.join(output_dir_merge, 'last_buffer')
    checkpoint_txt = os.path.join(output_dir_merge, 'last_checkpoint')
    if os.path.exists(checkpoint_txt):
        with open(checkpoint_txt, 'r') as checkpoint_f:
            cur_checkpoint = checkpoint_f.read()
    else:
        cur_checkpoint = 'no_checkpoint'

    checkpointer_embed.load(None, resume=True)
    print('load checkpoint from %s'%cur_checkpoint)
    model_merge.eval()
    for iteration, data_batch in enumerate(dataloader):
        print('epoch: %d, iteration: %d, size of binary: %d, size of context: %d'
              % (cur_epoch, iteration, len(xyz_pool1), len(context_xyz_pool1)))
        sys.stdout.flush()
        # reload the merge network if a newer checkpoint has appeared
        if os.path.exists(checkpoint_txt):
            with open(checkpoint_txt, 'r') as checkpoint_f:
                new_checkpoint = checkpoint_f.read()
            if cur_checkpoint != new_checkpoint:
                checkpointer_embed.load(None, resume=True)
                cur_checkpoint = new_checkpoint
                print('load checkpoint from %s' % cur_checkpoint)
                model_merge.eval()

        data_time = time.time() - end

        data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items()}

        #predict box's coords
        with torch.no_grad():
            preds = model(data_batch)
            loss_dict = loss_fn(preds, data_batch)
        total_loss = sum(loss_dict.values())
        meters.update(loss=total_loss, **loss_dict)
        with torch.no_grad():
            # TODO add loss_dict hack
            metric.update_dict(preds, data_batch)

        #extract box features
        batch_size, _, num_centroids, num_neighbours = data_batch['neighbour_xyz'].shape
        num_points = data_batch['points'].shape[-1]

        #batch_size, num_centroid, num_neighbor
        _, p = torch.max(preds['ins_logit'], 1)
        box_index_expand = torch.zeros((batch_size*num_centroids, num_points)).cuda()
        box_index_expand = box_index_expand.scatter_(dim=1, index=data_batch['neighbour_index'].reshape([-1, num_neighbours]), src=p.reshape([-1, num_neighbours]).float())
        centroid_label = data_batch['centroid_label'].reshape(-1)

        #remove proposals with fewer than minimum_box_pc_num points
        minimum_box_pc_num = 8
        minimum_overlap_pc_num = 8 #1/32 * num_neighbour
        gtmin_mask = (torch.sum(box_index_expand, dim=-1) > minimum_box_pc_num)

        #remove proposal whose purity score < 0.8
        box_label_expand = torch.zeros((batch_size*num_centroids, 200)).cuda()
        box_idx_expand = tile(data_batch['ins_id'],0,num_centroids).cuda()
        box_label_expand = box_label_expand.scatter_add_(dim=1, index=box_idx_expand, src=box_index_expand).float()
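        # box_label_expand[k, j] now counts how many points of proposal k fall
        # into ground-truth instance j, so the dominant instance and its point
        # count can be read off with the row-wise max below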
        maximum_label_num, maximum_label = torch.max(box_label_expand, 1)
        centroid_label = maximum_label
        total_num = torch.sum(box_label_expand, 1)
        box_purity = maximum_label_num / (total_num+1e-6)
        box_purity_mask = box_purity > 0.8
        box_purity_valid_mask = 1 - (box_purity < 0.8)*(box_purity > 0.6)
        box_purity_valid_mask *= gtmin_mask
        box_purity_valid_mask_l2 = gtmin_mask.long()#*data_batch['centroid_valid_mask'].reshape(-1).long()
        meters.update(purity_ratio=torch.sum(box_purity_mask).float() /
                      box_purity_mask.shape[0],
                      purity_valid_ratio=torch.sum(box_purity_valid_mask).float() /
                      box_purity_mask.shape[0])
        meters.update(purity_pos_num=torch.sum(box_purity_mask),
                      purity_neg_num=torch.sum(1 - box_purity_mask),
                      purity_neg_valid_num=torch.sum(box_purity < 0.6))
        centroid_valid_mask = data_batch['centroid_valid_mask'].reshape(-1).long()
        meters.update(
            centroid_valid_purity_ratio=torch.sum(
                torch.index_select(
                    box_purity_mask, dim=0,
                    index=centroid_valid_mask.nonzero().squeeze())).float() /
            torch.sum(centroid_valid_mask),
            centroid_nonvalid_purity_ratio=torch.sum(
                torch.index_select(
                    box_purity_mask, dim=0,
                    index=(1 - centroid_valid_mask).nonzero().squeeze())).float() /
            torch.sum(1 - centroid_valid_mask))
        purity_pred = torch.zeros([0]).type(torch.FloatTensor).cuda()

        #update pool by valid_mask
        valid_mask = gtmin_mask.long() *  box_purity_mask.long() * (centroid_label!=0).long()
        centroid_label = torch.index_select(centroid_label, dim=0, index=valid_mask.nonzero().squeeze())

        box_num = torch.sum(valid_mask.reshape(batch_size, num_centroids),1)
        cumsum_box_num = torch.cumsum(box_num, dim=0)
        cumsum_box_num = torch.cat([torch.from_numpy(np.array(0)).cuda().unsqueeze(0),cumsum_box_num],dim=0)

        #initialization
        pc_all = data_batch['points']
        centroid_label_all = centroid_label.clone()
        sub_xyz_pool1 = torch.zeros([0,3,1024]).float().cuda()
        sub_xyz_pool2 = torch.zeros([0,3,1024]).float().cuda()
        sub_context_xyz_pool1 = torch.zeros([0,3,1024]).float().cuda()
        sub_context_xyz_pool2 = torch.zeros([0,3,1024]).float().cuda()
        sub_context_context_xyz_pool = torch.zeros([0,3,2048]).float().cuda()
        sub_context_label_pool = torch.zeros([0]).float().cuda()
        sub_context_purity_pool = torch.zeros([0]).float().cuda()
        sub_label_pool = torch.zeros([0]).float().cuda()
        sub_purity_pool = torch.zeros([0]).float().cuda()
        sub_purity_xyz_pool = torch.zeros([0,3,1024]).float().cuda()
        sub_policy_purity_pool = torch.zeros([0,policy_update_bs]).float().cuda()
        sub_policy_reward_pool = torch.zeros([0,policy_update_bs]).float().cuda()
        sub_policy_xyz_pool1 = torch.zeros([0,policy_update_bs,3,1024]).float().cuda()
        sub_policy_xyz_pool2 = torch.zeros([0,policy_update_bs,3,1024]).float().cuda()
        for i in range(pc_all.shape[0]):
            bs = policy_total_bs
            BS = policy_update_bs

            pc = pc_all[i].clone()
            cur_mask_pool = box_index_expand[cumsum_box_num[i]:cumsum_box_num[i+1]].clone()
            centroid_label = centroid_label_all[cumsum_box_num[i]:cumsum_box_num[i+1]].clone()
            cover_ratio = torch.unique(cur_mask_pool.nonzero()[:,1]).shape[0]/num_points
            cur_xyz_pool, xyz_mean = mask_to_xyz(pc, cur_mask_pool)
            init_pool_size = cur_xyz_pool.shape[0]
            meters.update(cover_ratio=cover_ratio, init_pool_size=init_pool_size)
            negative_num = 0
            positive_num = 0

            #initial adjacency matrix (pairwise overlap counts between sub-part masks)
            inter_matrix = torch.matmul(cur_mask_pool, cur_mask_pool.transpose(0, 1))
            inter_matrix_full = inter_matrix.clone()>minimum_overlap_pc_num
            inter_matrix[torch.eye(inter_matrix.shape[0]).byte()] = 0
            pair_idx = (inter_matrix.triu()>minimum_overlap_pc_num).nonzero()
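            # pair_idx lists candidate proposal pairs that overlap by more than
            # minimum_overlap_pc_num points; only the upper triangle is kept so
            # each unordered pair appears once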
            zero_pair = torch.ones([0,2]).long()

            small_flag = False

            model_merge.eval()
            with torch.no_grad():
                while pair_idx.shape[0] > 0:
                    #when there are too few pairs, we calculate the policy score matrix on all pairs
                    if pair_idx.shape[0] <= BS and small_flag == False:
                        small_flag = True
                        purity_matrix = torch.zeros(inter_matrix.shape).cuda()
                        policy_matrix = torch.zeros(inter_matrix.shape).cuda()
                        bsp = 64
                        idx = torch.arange(pair_idx.shape[0]).cuda()
                        purity_pool = torch.zeros([0]).float().cuda()
                        policy_pool = torch.zeros([0]).float().cuda()
                        for k in range(int(np.ceil(idx.shape[0]/bsp))):
                            sub_part_idx = torch.index_select(pair_idx, dim=0, index=idx[k*bsp:(k+1)*bsp])
                            part_xyz1 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,0])
                            part_xyz2 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,1])
                            part_xyz = torch.cat([part_xyz1,part_xyz2],-1)
                            part_xyz -= torch.mean(part_xyz,-1).unsqueeze(-1)
                            part_norm = part_xyz.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                            part_xyz /= part_norm
                            logits_purity = model_merge(part_xyz, 'purity').squeeze()
                            if len(logits_purity.shape) == 0:
                                logits_purity = logits_purity.unsqueeze(0)
                            purity_pool = torch.cat([purity_pool, logits_purity], dim=0)

                            part_xyz11 = part_xyz1 - torch.mean(part_xyz1,-1).unsqueeze(-1)
                            part_xyz22 = part_xyz2 - torch.mean(part_xyz2,-1).unsqueeze(-1)
                            part_xyz11 /= part_norm
                            part_xyz22 /= part_norm
                            logits11 = model_merge(part_xyz11, 'policy')
                            logits22 = model_merge(part_xyz22, 'policy')
                            policy_scores = model_merge(torch.cat([logits11, logits22],dim=-1), 'policy_head').squeeze()
                            if len(policy_scores.shape) == 0:
                                policy_scores = policy_scores.unsqueeze(0)
                            policy_pool = torch.cat([policy_pool, policy_scores], dim=0)

                        purity_matrix[pair_idx[:,0],pair_idx[:,1]] = purity_pool
                        policy_matrix[pair_idx[:,0],pair_idx[:,1]] = policy_pool
                        score_matrix = torch.zeros(purity_matrix.shape).cuda()
                        score_matrix[pair_idx[:,0],pair_idx[:,1]] = softmax(purity_pool*policy_pool)
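                        # the selection score of each candidate pair is a softmax
                        # over purity * policy, restricted to the overlapping pairs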

                    #if there are many pairs, randomly sample a small batch, score it with the policy network, and keep the top-scoring pairs for the next stage
                    #otherwise, select the single pair with the highest policy score from the precomputed matrix
                    if pair_idx.shape[0] > BS and small_flag != True:
                        perm_idx = torch.randperm(pair_idx.shape[0]).cuda()
                        perm_idx_rnd = perm_idx[:bs]
                        sub_part_idx = torch.index_select(pair_idx, dim=0, index=perm_idx[:int(BS)])
                        part_xyz1 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,0])
                        part_xyz2 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,1])
                        part_xyz = torch.cat([part_xyz1,part_xyz2],-1)
                        part_xyz -= torch.mean(part_xyz,-1).unsqueeze(-1)
                        part_norm = part_xyz.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        part_xyz /= part_norm
                        logits_purity = model_merge(part_xyz, 'purity').squeeze()
                        sub_policy_purity_pool = torch.cat([sub_policy_purity_pool, logits_purity.detach().unsqueeze(0).clone()],dim=0)

                        part_xyz11 = part_xyz1 - torch.mean(part_xyz1,-1).unsqueeze(-1)
                        part_xyz22 = part_xyz2 - torch.mean(part_xyz2,-1).unsqueeze(-1)
                        part_xyz11 /= part_norm
                        part_xyz22 /= part_norm
                        logits11 = model_merge(part_xyz11, 'policy')
                        logits22 = model_merge(part_xyz22, 'policy')
                        policy_scores = model_merge(torch.cat([logits11, logits22],dim=-1), 'policy_head').squeeze()
                        sub_policy_xyz_pool1 = torch.cat([sub_policy_xyz_pool1, part_xyz11.unsqueeze(0).clone()], dim=0)
                        sub_policy_xyz_pool2 = torch.cat([sub_policy_xyz_pool2, part_xyz22.unsqueeze(0).clone()], dim=0)
                        if sub_policy_xyz_pool1.shape[0] > 64:
                            policy_xyz_pool1 = torch.cat([policy_xyz_pool1, sub_policy_xyz_pool1.cpu().clone()], dim=0)
                            policy_xyz_pool2 = torch.cat([policy_xyz_pool2, sub_policy_xyz_pool2.cpu().clone()], dim=0)
                            sub_policy_xyz_pool1 = torch.zeros([0,policy_update_bs,3,1024]).float().cuda()
                            sub_policy_xyz_pool2 = torch.zeros([0,policy_update_bs,3,1024]).float().cuda()
                        score = softmax(logits_purity*policy_scores)

                        part_label1 = torch.index_select(centroid_label, dim=0, index=sub_part_idx[:,0])
                        part_label2 = torch.index_select(centroid_label, dim=0, index=sub_part_idx[:,1])
                        siamese_label_gt = (part_label1 == part_label2)*(1 - (part_label1 == -1))*(1 - (part_label2 == -1))*(logits_purity>0.8)
                        sub_policy_reward_pool = torch.cat([sub_policy_reward_pool, siamese_label_gt.unsqueeze(0).float().clone()], dim=0)
                        loss_policy = -torch.sum(score*(siamese_label_gt.float()))
                        meters.update(loss_policy =loss_policy)

                        #we also mix in a few random samples to encourage exploration
                        _, rank_idx = torch.topk(score,bs,largest=True,sorted=False)
                        perm_idx = perm_idx[rank_idx]
                        perm_idx = torch.cat([perm_idx[:policy_total_bs-rnum], perm_idx_rnd[:rnum]], dim=0)
                        if cur_epoch == 1 and iteration < 128:
                            perm_idx = torch.randperm(pair_idx.shape[0]).cuda()
                            perm_idx = perm_idx[:policy_total_bs]
                    else:
                        score = score_matrix[pair_idx[:,0],pair_idx[:,1]]
                        _, rank_idx = torch.topk(score,1,largest=True,sorted=False)
                        perm_idx = rank_idx
                        if len(perm_idx.shape) == 0:
                            perm_idx = perm_idx.unsqueeze(0)

                        if cur_epoch == 1 and iteration < 128:
                            perm_idx = torch.randperm(pair_idx.shape[0]).cuda()
                            perm_idx = perm_idx[:1]

                    #send the selected pairs to the verification network
                    sub_part_idx = torch.index_select(pair_idx, dim=0, index=perm_idx[:bs])
                    part_xyz1 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,0])
                    part_xyz2 = torch.index_select(cur_xyz_pool, dim=0, index=sub_part_idx[:,1])
                    part_mask11 = torch.index_select(cur_mask_pool, dim=0, index=sub_part_idx[:,0])
                    part_mask22 = torch.index_select(cur_mask_pool, dim=0, index=sub_part_idx[:,1])
                    part_label1 = torch.index_select(centroid_label, dim=0, index=sub_part_idx[:,0])
                    part_label2 = torch.index_select(centroid_label, dim=0, index=sub_part_idx[:,1])
                    new_part_mask = 1-(1-part_mask11)*(1-part_mask22)
                    box_label_expand = torch.zeros((new_part_mask.shape[0], 200)).cuda()
                    box_idx_expand = tile(data_batch['ins_id'][i].unsqueeze(0),0,new_part_mask.shape[0]).cuda()
                    box_label_expand = box_label_expand.scatter_add_(dim=1, index=box_idx_expand, src=new_part_mask).float()
                    maximum_label_num, maximum_label = torch.max(box_label_expand, 1)
                    total_num = torch.sum(box_label_expand, 1)
                    box_purity = maximum_label_num / (total_num+1e-6)
                    sub_purity_pool = torch.cat([sub_purity_pool, box_purity.clone()], dim=0)
                    purity_xyz, xyz_mean = mask_to_xyz(pc, new_part_mask)
                    purity_xyz -= xyz_mean
                    purity_xyz /=(purity_xyz+1e-6).norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                    sub_purity_xyz_pool = torch.cat([sub_purity_xyz_pool, purity_xyz.clone()],dim=0)

                    siamese_label_gt = (part_label1 == part_label2)*(1 - (part_label1 == -1))*(1 - (part_label2 == -1))*(box_purity > 0.8)
                    negative_num += torch.sum(siamese_label_gt == 0)
                    positive_num += torch.sum(siamese_label_gt == 1)

                    #save data
                    sub_xyz_pool1 = torch.cat([sub_xyz_pool1, part_xyz1.clone()], dim=0)
                    sub_xyz_pool2 = torch.cat([sub_xyz_pool2, part_xyz2.clone()], dim=0)
                    sub_label_pool = torch.cat([sub_label_pool, siamese_label_gt.clone().float()], dim=0)

                    #renorm
                    part_xyz = torch.cat([part_xyz1,part_xyz2],-1)
                    part_xyz -= torch.mean(part_xyz,-1).unsqueeze(-1)
                    part_xyz11 = part_xyz1 - torch.mean(part_xyz1,-1).unsqueeze(-1)
                    part_xyz22 = part_xyz2 - torch.mean(part_xyz2,-1).unsqueeze(-1)
                    part_xyz11 /=part_xyz11.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                    part_xyz22 /=part_xyz22.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                    part_xyz /=part_xyz.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)

                    #save data
                    if cur_xyz_pool.shape[0] <= 32:
                        context_idx1 = torch.index_select(inter_matrix_full,dim=0,index=sub_part_idx[:,0])
                        context_idx2 = torch.index_select(inter_matrix_full,dim=0,index=sub_part_idx[:,1])
                        context_mask1 = (torch.matmul(context_idx1.float(), cur_mask_pool)>0).float()
                        context_mask2 = (torch.matmul(context_idx2.float(), cur_mask_pool)>0).float()
                        context_mask = ((context_mask1+context_mask2)>0).float()
                        context_xyz, xyz_mean = mask_to_xyz(pc, context_mask, sample_num=2048)
                        context_xyz = context_xyz - xyz_mean
                        context_xyz /= context_xyz.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                        sub_context_context_xyz_pool = torch.cat([sub_context_context_xyz_pool, context_xyz.clone()], dim=0)
                        sub_context_xyz_pool1 = torch.cat([sub_context_xyz_pool1, part_xyz1.clone()], dim=0)
                        sub_context_xyz_pool2 = torch.cat([sub_context_xyz_pool2, part_xyz2.clone()], dim=0)
                        sub_context_label_pool = torch.cat([sub_context_label_pool, siamese_label_gt.clone().float()], dim=0)
                        sub_context_purity_pool =  torch.cat([sub_context_purity_pool, box_purity.clone()], dim=0)

                    #at the very beginning, we group pairs according to ground-truth
                    if (cur_epoch == 1 and iteration < 128) or (cur_checkpoint == 'no_checkpoint'):
                        siamese_label = (part_label1 == part_label2)
                    #if we have many sub-parts in the pool, we use the binary branch to predict
                    elif cur_xyz_pool.shape[0] > 32:
                        logits1 = model_merge(part_xyz11,'backbone')
                        logits2 = model_merge(part_xyz22,'backbone')
                        merge_logits = model_merge(torch.cat([part_xyz, torch.cat([logits1.unsqueeze(-1).expand(-1,-1,part_xyz1.shape[-1]), logits2.unsqueeze(-1).expand(-1,-1,part_xyz2.shape[-1])], dim=-1)], dim=1), 'head')
                        _, p = torch.max(merge_logits, 1)
                        siamese_label = p
                    #if there are too few sub-parts in the pool, we use the context branch to predict
                    else:
                        logits1 = model_merge(part_xyz11,'backbone')
                        logits2 = model_merge(part_xyz22,'backbone')
                        context_logits = model_merge(context_xyz,'backbone2')
                        merge_logits = model_merge(torch.cat([part_xyz, torch.cat([logits1.unsqueeze(-1).expand(-1,-1,part_xyz1.shape[-1]), logits2.unsqueeze(-1).expand(-1,-1,part_xyz2.shape[-1])], dim=-1), torch.cat([context_logits.unsqueeze(-1).expand(-1,-1,part_xyz.shape[-1])], dim=-1)], dim=1), 'head2')
                        _, p = torch.max(merge_logits, 1)
                        siamese_label = p


                    #group sub-parts according to the prediction
                    merge_idx1 = torch.index_select(sub_part_idx[:,0], dim=0, index=siamese_label.nonzero().squeeze())
                    merge_idx2 = torch.index_select(sub_part_idx[:,1], dim=0, index=siamese_label.nonzero().squeeze())
                    merge_idx = torch.unique(torch.cat([merge_idx1, merge_idx2], dim=0))
                    nonmerge_idx1 = torch.index_select(sub_part_idx[:,0], dim=0, index=(1-siamese_label).nonzero().squeeze())
                    nonmerge_idx2 = torch.index_select(sub_part_idx[:,1], dim=0, index=(1-siamese_label).nonzero().squeeze())
                    part_mask1 = torch.index_select(cur_mask_pool, dim=0, index=merge_idx1)
                    part_mask2 = torch.index_select(cur_mask_pool, dim=0, index=merge_idx2)
                    new_part_mask = 1-(1-part_mask1)*(1-part_mask2)
                    new_part_label = torch.index_select(part_label1, dim=0, index=siamese_label.nonzero().squeeze()).long()
                    new_part_label_invalid = torch.index_select(siamese_label_gt, dim=0, index=siamese_label.nonzero().squeeze()).long()
                    new_part_label = new_part_label*new_part_label_invalid + -1*(1-new_part_label_invalid)
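                    # merged parts keep a ground-truth label only when the pair was
                    # truly mergeable (siamese_label_gt == 1); otherwise the label is
                    # set to -1 so it is treated as unknown later on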

                    #sometimes, we may obtain several identical sub-parts
                    #for those, we only keep one
                    equal_matrix = torch.matmul(new_part_mask,1-new_part_mask.transpose(0,1))+torch.matmul(1-new_part_mask,new_part_mask.transpose(0,1))
                    equal_matrix[torch.eye(equal_matrix.shape[0]).byte()]=1
                    fid = (equal_matrix==0).nonzero()
                    if fid.shape[0] > 0:
                        flag = torch.ones(equal_matrix.shape[0])
                        for k in range(flag.shape[0]):
                            if flag[k] != 0:
                                flag[fid[:,1][fid[:,0]==k]] = 0
                        new_part_mask = torch.index_select(new_part_mask, dim=0, index=flag.nonzero().squeeze().cuda())
                        new_part_label = torch.index_select(new_part_label, dim=0, index=flag.nonzero().squeeze().cuda())

                    new_part_xyz, xyz_mean = mask_to_xyz(pc, new_part_mask)

                    #when there are only a few pairs left, update the policy score matrix incrementally so the whole matrix does not need to be recomputed every time
                    if small_flag and (new_part_mask.shape[0] > 0):
                        overlap_idx = (torch.matmul(cur_mask_pool, new_part_mask.transpose(0,1))>minimum_overlap_pc_num).nonzero().squeeze()
                        if overlap_idx.shape[0] > 0:
                            if len(overlap_idx.shape) == 1:
                                overlap_idx = overlap_idx.unsqueeze(0)
                            part_xyz1 = torch.index_select(cur_xyz_pool, dim=0, index=overlap_idx[:,0])
                            part_xyz2 = tile(new_part_xyz, 0, overlap_idx.shape[0])
                            part_xyz = torch.cat([part_xyz1,part_xyz2],-1)
                            part_xyz -= torch.mean(part_xyz,-1).unsqueeze(-1)
                            part_norm = part_xyz.norm(dim=1).max(dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
                            part_xyz /= part_norm
                            overlap_purity_scores = model_merge(part_xyz, 'purity').squeeze()

                            part_xyz11 = part_xyz1 - torch.mean(part_xyz1,-1).unsqueeze(-1)
                            part_xyz22 = part_xyz2 - torch.mean(part_xyz2,-1).unsqueeze(-1)
                            part_xyz11 /= part_norm
                            part_xyz22 /= part_norm
                            logits11 = model_merge(part_xyz11, 'policy')
                            logits22 = model_merge(part_xyz22, 'policy')
                            overlap_policy_scores = model_merge(torch.cat([logits11, logits22],dim=-1), 'policy_head').squeeze()

                            tmp_purity_arr = torch.zeros([purity_matrix.shape[0]]).cuda()
                            tmp_policy_arr = torch.zeros([policy_matrix.shape[0]]).cuda()
                            tmp_purity_arr[overlap_idx[:,0]] = overlap_purity_scores
                            tmp_policy_arr[overlap_idx[:,0]] = overlap_policy_scores
                            purity_matrix = torch.cat([purity_matrix,tmp_purity_arr.unsqueeze(1)],dim=1)
                            policy_matrix = torch.cat([policy_matrix,tmp_policy_arr.unsqueeze(1)],dim=1)
                            purity_matrix = torch.cat([purity_matrix,torch.zeros(purity_matrix.shape[1]).cuda().unsqueeze(0)])
                            policy_matrix = torch.cat([policy_matrix,torch.zeros(policy_matrix.shape[1]).cuda().unsqueeze(0)])
                        else:
                            purity_matrix = torch.cat([purity_matrix,torch.zeros(purity_matrix.shape[0]).cuda().unsqueeze(1)],dim=1)
                            policy_matrix = torch.cat([policy_matrix,torch.zeros(policy_matrix.shape[0]).cuda().unsqueeze(1)],dim=1)
                            purity_matrix = torch.cat([purity_matrix,torch.zeros(purity_matrix.shape[1]).cuda().unsqueeze(0)])
                            policy_matrix = torch.cat([policy_matrix,torch.zeros(policy_matrix.shape[1]).cuda().unsqueeze(0)])

                    #update cur_pool: append the newly merged parts and drop the sub-parts that were merged into them
                    cur_mask_pool = torch.cat([cur_mask_pool, new_part_mask], dim=0)
                    cur_xyz_pool = torch.cat([cur_xyz_pool, new_part_xyz], dim=0)
                    centroid_label = torch.cat([centroid_label, new_part_label], dim=0)
                    cur_pool_size = cur_mask_pool.shape[0]
                    new_mask = torch.ones([cur_pool_size])
                    new_mask[merge_idx] = 0
                    new_idx = new_mask.nonzero().squeeze().cuda()
                    cur_xyz_pool = torch.index_select(cur_xyz_pool, dim=0, index=new_idx)
                    cur_mask_pool = torch.index_select(cur_mask_pool, dim=0, index=new_idx)
                    centroid_label = torch.index_select(centroid_label, dim=0, index=new_idx)
                    inter_matrix = torch.matmul(cur_mask_pool, cur_mask_pool.transpose(0, 1))
                    inter_matrix_full = inter_matrix.clone()>minimum_overlap_pc_num
                    #update zero_matrix
                    zero_matrix = torch.zeros([cur_pool_size, cur_pool_size])
                    zero_matrix[zero_pair[:,0], zero_pair[:,1]] = 1
                    zero_matrix[nonmerge_idx1, nonmerge_idx2] = 1
                    zero_matrix[nonmerge_idx2, nonmerge_idx1] = 1
                    zero_matrix = torch.index_select(zero_matrix, dim=0, index=new_idx.cpu())
                    zero_matrix = torch.index_select(zero_matrix, dim=1, index=new_idx.cpu())
                    zero_pair = zero_matrix.nonzero()
                    inter_matrix[zero_pair[:,0], zero_pair[:,1]] = 0
                    inter_matrix[torch.eye(inter_matrix.shape[0]).byte()] = 0
                    pair_idx = (inter_matrix.triu()>minimum_overlap_pc_num).nonzero()
                    if small_flag == True:
                        purity_matrix = torch.index_select(purity_matrix, dim=0, index=new_idx)
                        purity_matrix = torch.index_select(purity_matrix, dim=1, index=new_idx)
                        policy_matrix = torch.index_select(policy_matrix, dim=0, index=new_idx)
                        policy_matrix = torch.index_select(policy_matrix, dim=1, index=new_idx)
                        score_matrix = torch.zeros(purity_matrix.shape).cuda()
                        score_idx = pair_idx
                        score_matrix[score_idx[:,0], score_idx[:,1]] = softmax(purity_matrix[score_idx[:,0], score_idx[:,1]] * policy_matrix[score_idx[:,0], score_idx[:,1]])
                final_pool_size = negative_num + positive_num
                meters.update(final_pool_size=final_pool_size,negative_num=negative_num, positive_num=positive_num)
        xyz_pool1 = torch.cat([xyz_pool1, sub_xyz_pool1.cpu().clone()],dim=0)
        xyz_pool2 = torch.cat([xyz_pool2, sub_xyz_pool2.cpu().clone()],dim=0)
        label_pool = torch.cat([label_pool, sub_label_pool.cpu().clone()], dim=0)
        context_context_xyz_pool = torch.cat([context_context_xyz_pool, sub_context_context_xyz_pool.cpu().clone()],dim=0)
        context_xyz_pool1 = torch.cat([context_xyz_pool1, sub_context_xyz_pool1.cpu().clone()],dim=0)
        context_xyz_pool2 = torch.cat([context_xyz_pool2, sub_context_xyz_pool2.cpu().clone()],dim=0)
        context_label_pool = torch.cat([context_label_pool, sub_context_label_pool.cpu().clone()], dim=0)
        context_purity_pool = torch.cat([context_purity_pool, sub_context_purity_pool.cpu().clone()], dim=0)
        purity_purity_pool = torch.cat([purity_purity_pool, sub_purity_pool.cpu().clone()], dim=0)
        purity_xyz_pool = torch.cat([purity_xyz_pool, sub_purity_xyz_pool.cpu().clone()], dim=0)
        policy_purity_pool = torch.cat([policy_purity_pool, sub_policy_purity_pool.cpu().clone()], dim=0)
        policy_reward_pool = torch.cat([policy_reward_pool, sub_policy_reward_pool.cpu().clone()], dim=0)
        policy_xyz_pool1 = torch.cat([policy_xyz_pool1, sub_policy_xyz_pool1.cpu().clone()], dim=0)
        policy_xyz_pool2 = torch.cat([policy_xyz_pool2, sub_policy_xyz_pool2.cpu().clone()], dim=0)
        produce_time = time.time() - end

        # once enough samples have accumulated, dump the pools to a replay-buffer file on disk and reset them
        if context_xyz_pool1.shape[0] > 10000:
            rbuffer = dict()
            rbuffer['xyz_pool1'] = xyz_pool1
            rbuffer['xyz_pool2'] = xyz_pool2
            rbuffer['context_xyz_pool1'] = context_xyz_pool1
            rbuffer['context_xyz_pool2'] = context_xyz_pool2
            rbuffer['context_context_xyz_pool'] = context_context_xyz_pool
            rbuffer['context_label_pool'] = context_label_pool
            rbuffer['context_purity_pool'] = context_purity_pool
            rbuffer['label_pool'] = label_pool
            rbuffer['purity_purity_pool'] = purity_purity_pool
            rbuffer['purity_xyz_pool'] = purity_xyz_pool
            rbuffer['policy_purity_pool'] = policy_purity_pool
            rbuffer['policy_reward_pool'] = policy_reward_pool
            rbuffer['policy_xyz_pool1'] = policy_xyz_pool1
            rbuffer['policy_xyz_pool2'] = policy_xyz_pool2
            torch.save(rbuffer, os.path.join(output_dir_merge, 'buffer', '%d_%d.pt'%(cur_epoch, iteration)))
            with open(buffer_txt, 'w') as buffer_f:
                buffer_f.write('%d_%d'%(cur_epoch, iteration))
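            # the trainer process polls 'last_buffer' and reloads the .pt file it
            # points to (see the consumer loop in Example #10)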
            p = Popen('rm -rf %s'%(os.path.join(output_dir_merge, 'buffer', '%d_%d.pt'%(cur_epoch, old_iteration))), shell=True)
            old_iteration = iteration
            p = Popen('rm -rf %s_*'%(os.path.join(output_dir_merge, 'buffer', '%d'%(cur_epoch-1))), shell=True)

            xyz_pool1 = torch.zeros([0,3,1024]).float()
            xyz_pool2 = torch.zeros([0,3,1024]).float()
            context_xyz_pool1 = torch.zeros([0,3,1024]).float()
            context_xyz_pool2 = torch.zeros([0,3,1024]).float()
            context_context_xyz_pool = torch.zeros([0,3,2048]).float()
            context_label_pool = torch.zeros([0]).float()
            context_purity_pool = torch.zeros([0]).float()
            label_pool = torch.zeros([0]).float()
            purity_purity_pool = torch.zeros([0]).float()
            purity_xyz_pool = torch.zeros([0,3,1024]).float()
            policy_purity_pool = torch.zeros([0,policy_update_bs]).float()
            policy_reward_pool = torch.zeros([0,policy_update_bs]).float()
            policy_xyz_pool1 = torch.zeros([0,policy_update_bs,3,1024]).float()
            policy_xyz_pool2 = torch.zeros([0,policy_update_bs,3,1024]).float()



        batch_time = time.time() - end
        end = time.time()
        meters.update(data_time=data_time, produce_time=produce_time)

        if log_period > 0 and iteration % log_period == 0:
            logger.info(
                meters.delimiter.join(
                    [
                        'iter: {iter:4d}',
                        '{meters}',
                        'max mem: {memory:.0f}',
                    ]
                ).format(
                    iter=iteration,
                    meters=str(meters),
                    memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2),
                )
            )
    return meters
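Example #7 also depends on two helpers, mask_to_xyz and tile, that are defined elsewhere. The sketches below show plausible implementations inferred from how they are called (fixed-size resampling of masked points, and interleaved repetition along a dimension); treat them as assumptions rather than the original definitions.

import torch

def mask_to_xyz(pc, mask_pool, sample_num=1024):
    # pc: [3, num_points] point cloud, mask_pool: [K, num_points] binary masks
    # returns ([K, 3, sample_num] resampled coordinates, [K, 3, 1] per-part mean)
    parts, means = [], []
    for mask in mask_pool:
        idx = mask.nonzero(as_tuple=False).squeeze(1)
        if idx.numel() == 0:
            idx = torch.zeros(1, dtype=torch.long, device=pc.device)
        # sample with replacement so every part has the same number of points
        choice = idx[torch.randint(idx.numel(), (sample_num,), device=pc.device)]
        part = pc[:, choice]
        parts.append(part)
        means.append(part.mean(dim=-1, keepdim=True))
    if not parts:
        return (torch.zeros(0, 3, sample_num, device=pc.device),
                torch.zeros(0, 3, 1, device=pc.device))
    return torch.stack(parts, dim=0), torch.stack(means, dim=0)

def tile(a, dim, n_tile):
    # assumed to repeat each slice along `dim` n_tile times in interleaved order,
    # i.e. equivalent to torch.repeat_interleave
    return torch.repeat_interleave(a, n_tile, dim=dim)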
Example #8
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("SelfSupervised.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(classifier)

    batch_size = cfg.SOLVER.BATCH_SIZE
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(cfg.SOLVER.BATCH_SIZE /
                         torch.distributed.get_world_size())
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg, mode='train', is_source=True)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(src_train_data,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               num_workers=4,
                                               pin_memory=True,
                                               sampler=train_sampler,
                                               drop_last=True)

    ce_criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    start_training_time = time.time()
    end = time.time()

    for i, (src_input, src_label, _) in enumerate(train_loader):
        data_time = time.time() - end
        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10
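        # the classifier head is updated with 10x the backbone learning rate,
        # matching the 10x factor used when optimizer_cls was constructed above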

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        size = src_label.shape[-2:]
        pred = classifier(feature_extractor(src_input), size)
        loss = ce_criterion(pred, src_label)
        loss.backward()

        optimizer_fea.step()
        optimizer_cls.step()
        meters.update(loss_seg=loss.item())
        iteration += 1

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)
        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.2f} GB",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 /
                    1024.0 / 1024.0,
                ))

        if (iteration % cfg.SOLVER.CHECKPOINT_PERIOD == 0
                or iteration == max_iters) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.STOP_ITER))

    return feature_extractor, classifier
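This trainer also assumes two small utilities imported from elsewhere, strip_prefix_if_present and adjust_learning_rate. The sketches below show one common way such helpers are written (stripping the 'module.' prefix from DataParallel checkpoints, and a 'poly' learning-rate decay); the exact originals may differ.

from collections import OrderedDict

def strip_prefix_if_present(state_dict, prefix):
    # drop a leading prefix (e.g. 'module.') from every checkpoint key, if all keys have it
    keys = sorted(state_dict.keys())
    if not all(key.startswith(prefix) for key in keys):
        return state_dict
    return OrderedDict(
        (key[len(prefix):], value) for key, value in state_dict.items())

def adjust_learning_rate(method, base_lr, iteration, max_iters, power=0.9):
    # assumed 'poly' schedule: lr = base_lr * (1 - iteration / max_iters) ** power
    if method == 'poly':
        return base_lr * (1 - float(iteration) / max_iters) ** power
    return base_lr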
Example #9
def train_one_epoch(model,
                    loss_fn,
                    metric,
                    dataloader,
                    optimizer,
                    max_grad_norm=0.0,
                    freezer=None,
                    log_period=-1):
    logger = logging.getLogger('shaper.train')
    meters = MetricLogger(delimiter='  ')
    # reset metrics
    metric.reset()
    meters.bind(metric)
    # set training mode
    model.train()
    if freezer is not None:
        freezer.freeze()
    loss_fn.train()
    metric.train()

    end = time.time()
    for iteration, data_batch in enumerate(dataloader):
        data_time = time.time() - end

        data_batch = {
            k: v.cuda(non_blocking=True)
            for k, v in data_batch.items()
        }

        preds = model(data_batch)

        # backward
        optimizer.zero_grad()
        loss_dict = loss_fn(preds, data_batch)
        total_loss = sum(loss_dict.values())

        # It is slightly faster to update metrics and meters before backward
        meters.update(loss=total_loss, **loss_dict)
        with torch.no_grad():
            metric.update_dict(preds, data_batch)

        total_loss.backward()
        if max_grad_norm > 0:
            # CAUTION: built-in clip_grad_norm_ clips the total norm.
            nn.utils.clip_grad_norm_(model.parameters(),
                                     max_norm=max_grad_norm)
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        if log_period > 0 and iteration % log_period == 0:
            logger.info(
                meters.delimiter.join([
                    'iter: {iter:4d}',
                    '{meters}',
                    'lr: {lr:.2e}',
                    'max mem: {memory:.0f}',
                ]).format(
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0**2),
                ))
    return meters
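A typical way to drive this epoch function from an outer loop is sketched below; the surrounding objects (train_loader, scheduler, max_epochs) are illustrative placeholders, not part of the snippet above.

for epoch in range(max_epochs):
    train_meters = train_one_epoch(model, loss_fn, metric, train_loader,
                                   optimizer, max_grad_norm=0.0, log_period=50)
    scheduler.step()
    logger.info('Epoch[{:d}] train {}'.format(epoch, train_meters.summary_str))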
Example #10
def train_one_epoch(model_merge,
                    cur_epoch,
                    optimizer_embed,
                    output_dir_merge,
                    max_grad_norm=0.0,
                    freezer=None,
                    log_period=-1):
    global xyz_pool1
    global xyz_pool2
    global context_xyz_pool1
    global context_xyz_pool2
    global context_context_xyz_pool
    global context_label_pool
    global context_purity_pool
    global label_pool
    global purity_purity_pool
    global purity_xyz_pool
    global policy_purity_pool
    global policy_reward_pool
    global policy_xyz_pool1
    global policy_xyz_pool2
    global cur_buffer
    global rbuffer
    global count

    logger = logging.getLogger('shaper.train')
    meters = MetricLogger(delimiter='  ')

    softmax = nn.Softmax(dim=-1)  # explicit dim; scores are normalized over the last axis
    end = time.time()
    model_merge.train()
    sys.stdout.flush()
    BS = policy_update_bs
    print('epoch: %d' % cur_epoch)

    #delete older models
    if cur_epoch > 2:
        if (cur_epoch - 2) % 400 != 0:
            p = Popen('rm %s' %
                      (os.path.join(output_dir_merge, 'model_%03d.pth' %
                                    (cur_epoch - 2))),
                      shell=True)

    #keep reading the newest data generated by the producer process
    while True:
        buffer_txt = os.path.join(output_dir_merge, 'last_buffer')
        with open(buffer_txt, 'r') as buffer_file:
            new_buffer = buffer_file.read()

        if new_buffer != 'start':
            if cur_buffer != new_buffer:
                count = 0
                cur_buffer = new_buffer
                print('read data from %s' % cur_buffer)
                rbuffer = torch.load(
                    os.path.join(output_dir_merge, 'buffer',
                                 '%s.pt' % cur_buffer))
                break
            count += 1
            if count <= 2:
                break
        time.sleep(10)

    #read data
    xyz_pool1 = rbuffer['xyz_pool1']
    xyz_pool2 = rbuffer['xyz_pool2']
    context_xyz_pool1 = rbuffer['context_xyz_pool1']
    context_xyz_pool2 = rbuffer['context_xyz_pool2']
    context_context_xyz_pool = rbuffer['context_context_xyz_pool']
    context_label_pool = rbuffer['context_label_pool']
    context_purity_pool = rbuffer['context_purity_pool']
    label_pool = rbuffer['label_pool']
    purity_purity_pool = rbuffer['purity_purity_pool']
    purity_xyz_pool = rbuffer['purity_xyz_pool']
    policy_purity_pool = rbuffer['policy_purity_pool']
    policy_reward_pool = rbuffer['policy_reward_pool']
    policy_xyz_pool1 = rbuffer['policy_xyz_pool1']
    policy_xyz_pool2 = rbuffer['policy_xyz_pool2']
    # several training passes over the freshly loaded buffer
    for _ in range(20):
        bs2 = 64
        TRAIN_LEN = 1024
        UP_policy = 2048
        TRAIN_LEN_policy = 32
        bs_policy = int(128 / policy_update_bs)
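        # each stored policy sample already contains policy_update_bs pairs, so
        # bs_policy keeps the effective number of pairs per update close to 128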

        #train binary branch
        cur_len = xyz_pool1.shape[0]
        cur_train_len = TRAIN_LEN if cur_len > TRAIN_LEN else cur_len
        perm_idx = torch.randperm(cur_len)
        logits1_all = torch.zeros([0]).type(torch.LongTensor).cuda()
        sub_xyz_pool1 = torch.index_select(xyz_pool1,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_xyz_pool2 = torch.index_select(xyz_pool2,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_label_pool = torch.index_select(label_pool,
                                            dim=0,
                                            index=perm_idx[:cur_train_len])
        perm_idx = torch.arange(cur_train_len)
        for i in range(int(cur_train_len / bs2)):
            optimizer_embed.zero_grad()
            part_xyz1 = torch.index_select(sub_xyz_pool1,
                                           dim=0,
                                           index=perm_idx[i * bs2:(i + 1) *
                                                          bs2]).cuda()
            part_xyz2 = torch.index_select(sub_xyz_pool2,
                                           dim=0,
                                           index=perm_idx[i * bs2:(i + 1) *
                                                          bs2]).cuda()
            siamese_label = torch.index_select(sub_label_pool,
                                               dim=0,
                                               index=perm_idx[i * bs2:(i + 1) *
                                                              bs2]).cuda()
            part_xyz = torch.cat([part_xyz1, part_xyz2], -1)
            part_xyz -= torch.mean(part_xyz, -1).unsqueeze(-1)
            part_xyz1 -= torch.mean(part_xyz1, -1).unsqueeze(-1)
            part_xyz2 -= torch.mean(part_xyz2, -1).unsqueeze(-1)
            part_xyz1 /= part_xyz1.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            part_xyz2 /= part_xyz2.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            part_xyz /= part_xyz.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            logits1 = model_merge(part_xyz1, 'backbone')
            logits2 = model_merge(part_xyz2, 'backbone')
            merge_logits = model_merge(
                torch.cat([
                    part_xyz,
                    torch.cat([
                        logits1.unsqueeze(-1).expand(-1, -1,
                                                     part_xyz1.shape[-1]),
                        logits2.unsqueeze(-1).expand(-1, -1,
                                                     part_xyz2.shape[-1])
                    ],
                              dim=-1)
                ],
                          dim=1), 'head')
            _, p = torch.max(merge_logits, 1)
            logits1_all = torch.cat([logits1_all, p], dim=0)
            merge_acc_arr = (p == siamese_label.long()).float()
            meters.update(meters_acc=torch.mean(merge_acc_arr))
            if torch.sum(siamese_label) != 0:
                merge_pos_acc = torch.mean(
                    torch.index_select(
                        merge_acc_arr,
                        dim=0,
                        index=(siamese_label == 1).nonzero().squeeze()))
                meters.update(merge_pos_acc=merge_pos_acc)
            if torch.sum(1 - siamese_label) != 0:
                merge_neg_acc = torch.mean(
                    torch.index_select(
                        merge_acc_arr,
                        dim=0,
                        index=(siamese_label == 0).nonzero().squeeze()))
                meters.update(merge_neg_acc=merge_neg_acc)

            loss_sim = cross_entropy(merge_logits, siamese_label.long())

            loss_dict_embed = {
                'loss_sim': loss_sim,
            }
            meters.update(**loss_dict_embed)
            total_loss_embed = sum(loss_dict_embed.values())
            total_loss_embed.backward()
            optimizer_embed.step()

        #train context branch
        cur_len = context_xyz_pool1.shape[0]
        cur_train_len = TRAIN_LEN if cur_len > TRAIN_LEN else cur_len
        perm_idx = torch.randperm(cur_len)
        logits1_all = torch.zeros([0]).type(torch.LongTensor).cuda()
        sub_xyz_pool1 = torch.index_select(context_xyz_pool1,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_xyz_pool2 = torch.index_select(context_xyz_pool2,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_label_pool = torch.index_select(context_label_pool,
                                            dim=0,
                                            index=perm_idx[:cur_train_len])
        sub_context_context_xyz_pool = torch.index_select(
            context_context_xyz_pool, dim=0, index=perm_idx[:cur_train_len])
        perm_idx = torch.arange(cur_train_len)
        for i in range(int(cur_train_len / bs2)):
            optimizer_embed.zero_grad()
            part_xyz1 = torch.index_select(sub_xyz_pool1,
                                           dim=0,
                                           index=perm_idx[i * bs2:(i + 1) *
                                                          bs2]).cuda()
            part_xyz2 = torch.index_select(sub_xyz_pool2,
                                           dim=0,
                                           index=perm_idx[i * bs2:(i + 1) *
                                                          bs2]).cuda()
            siamese_label = torch.index_select(sub_label_pool,
                                               dim=0,
                                               index=perm_idx[i * bs2:(i + 1) *
                                                              bs2]).cuda()
            part_xyz = torch.cat([part_xyz1, part_xyz2], -1)
            part_xyz -= torch.mean(part_xyz, -1).unsqueeze(-1)
            part_xyz1 -= torch.mean(part_xyz1, -1).unsqueeze(-1)
            part_xyz2 -= torch.mean(part_xyz2, -1).unsqueeze(-1)
            part_xyz1 /= part_xyz1.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            part_xyz2 /= part_xyz2.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            part_xyz /= part_xyz.norm(dim=1).max(
                dim=-1)[0].unsqueeze(-1).unsqueeze(-1)
            logits1 = model_merge(part_xyz1, 'backbone')
            logits2 = model_merge(part_xyz2, 'backbone')
            context_xyz = torch.index_select(sub_context_context_xyz_pool,
                                             dim=0,
                                             index=perm_idx[i * bs2:(i + 1) *
                                                            bs2]).cuda()
            context_logits = model_merge(context_xyz, 'backbone2')
            merge_logits = model_merge(
                torch.cat([
                    part_xyz,
                    torch.cat([
                        logits1.detach().unsqueeze(-1).expand(
                            -1, -1, part_xyz1.shape[-1]),
                        logits2.detach().unsqueeze(-1).expand(
                            -1, -1, part_xyz2.shape[-1])
                    ],
                              dim=-1),
                    torch.cat([
                        context_logits.unsqueeze(-1).expand(
                            -1, -1, part_xyz.shape[-1])
                    ],
                              dim=-1)
                ],
                          dim=1), 'head2')
            _, p = torch.max(merge_logits, 1)
            logits1_all = torch.cat([logits1_all, p], dim=0)
            merge_acc_arr = (p == siamese_label.long()).float()
            meters.update(meters_acc_context=torch.mean(merge_acc_arr))
            if torch.sum(siamese_label) != 0:
                merge_pos_acc = torch.mean(
                    torch.index_select(
                        merge_acc_arr,
                        dim=0,
                        index=(siamese_label == 1).nonzero().squeeze()))
                meters.update(merge_pos_acc_context=merge_pos_acc)
            if torch.sum(1 - siamese_label) != 0:
                merge_neg_acc = torch.mean(
                    torch.index_select(
                        merge_acc_arr,
                        dim=0,
                        index=(siamese_label == 0).nonzero().squeeze()))
                meters.update(merge_neg_acc_context=merge_neg_acc)

            loss_sim = cross_entropy(merge_logits, siamese_label.long())

            loss_dict_embed = {
                'loss_sim_context': loss_sim,
            }
            meters.update(**loss_dict_embed)
            total_loss_embed = sum(loss_dict_embed.values())
            total_loss_embed.backward()
            optimizer_embed.step()

        #train purity network
        cur_len = purity_purity_pool.shape[0]
        cur_train_len = TRAIN_LEN if cur_len > TRAIN_LEN else cur_len
        perm_idx = torch.randperm(cur_len)
        sub_purity_pool = torch.index_select(purity_purity_pool,
                                             dim=0,
                                             index=perm_idx[:cur_train_len])
        sub_purity_xyz_pool = torch.index_select(
            purity_xyz_pool, dim=0, index=perm_idx[:cur_train_len])
        perm_idx = torch.arange(cur_train_len)
        for i in range(int(cur_train_len / bs2)):
            optimizer_embed.zero_grad()
            part_xyz = torch.index_select(sub_purity_xyz_pool,
                                          dim=0,
                                          index=perm_idx[i * bs2:(i + 1) *
                                                         bs2]).cuda()
            logits_purity = model_merge(part_xyz, 'purity')
            siamese_label_l2 = torch.index_select(
                sub_purity_pool, dim=0,
                index=perm_idx[i * bs2:(i + 1) * bs2]).cuda()
            loss_purity = l2_loss(logits_purity.squeeze(), siamese_label_l2)
            loss_dict_embed = {
                'loss_purity2': loss_purity,
            }
            meters.update(**loss_dict_embed)
            total_loss_embed = sum(loss_dict_embed.values())
            total_loss_embed.backward()
            optimizer_embed.step()

        #train policy network
        cur_len = policy_xyz_pool1.shape[0]
        cur_train_len = TRAIN_LEN_policy if cur_len > TRAIN_LEN_policy else cur_len
        perm_idx = torch.randperm(cur_len)
        logits1_all = torch.zeros([0]).type(torch.LongTensor).cuda()
        sub_xyz_pool1 = torch.index_select(policy_xyz_pool1,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_xyz_pool2 = torch.index_select(policy_xyz_pool2,
                                           dim=0,
                                           index=perm_idx[:cur_train_len])
        sub_purity_pool = torch.index_select(policy_purity_pool,
                                             dim=0,
                                             index=perm_idx[:cur_train_len])
        sub_reward_pool = torch.index_select(policy_reward_pool,
                                             dim=0,
                                             index=perm_idx[:cur_train_len])
        perm_idx = torch.arange(cur_train_len)
        for i in range(int(cur_train_len / bs_policy)):
            optimizer_embed.zero_grad()
            part_xyz1 = torch.index_select(
                sub_xyz_pool1,
                dim=0,
                index=perm_idx[i * bs_policy:(i + 1) * bs_policy]).cuda()
            part_xyz2 = torch.index_select(
                sub_xyz_pool2,
                dim=0,
                index=perm_idx[i * bs_policy:(i + 1) * bs_policy]).cuda()
            purity_arr = torch.index_select(
                sub_purity_pool,
                dim=0,
                index=perm_idx[i * bs_policy:(i + 1) * bs_policy]).cuda()
            reward_arr = torch.index_select(
                sub_reward_pool,
                dim=0,
                index=perm_idx[i * bs_policy:(i + 1) * bs_policy]).cuda()
            logits11 = model_merge(
                part_xyz1.reshape([bs_policy * BS, 3, 1024]), 'policy')
            logits22 = model_merge(
                part_xyz2.reshape([bs_policy * BS, 3, 1024]), 'policy')
            policy_arr = model_merge(torch.cat([logits11, logits22], dim=-1),
                                     'policy_head').squeeze()
            policy_arr = policy_arr.reshape([bs_policy, BS])
            score_arr = softmax(policy_arr * purity_arr)
            loss_policy = torch.mean(-torch.sum(score_arr * reward_arr, dim=1))
            meters.update(loss_policy=loss_policy)
            loss_policy.backward()
            optimizer_embed.step()

        if max_grad_norm > 0:
            nn.utils.clip_grad_norm_(model_merge.parameters(),
                                     max_norm=max_grad_norm)

    train_time = time.time() - end
    meters.update(train_time=train_time)

    logger.info(
        meters.delimiter.join([
            '{meters}',
            'lr_embed: {lr_embed:.4e}',
        ]).format(
            meters=str(meters),
            lr_embed=optimizer_embed.param_groups[0]['lr'],
        ))
    meters.update(lr_embed=optimizer_embed.param_groups[0]['lr'])
    return meters
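
The policy step above turns purity-weighted scores into a softmax over the BS candidates of each sample and then maximises the expected reward. A minimal, self-contained sketch of that loss (the shapes and the helper name are assumptions; the real model_merge heads are not reproduced):

import torch
import torch.nn.functional as F


def policy_loss(policy_logits, purity, reward):
    # policy_logits, purity, reward: (batch, num_candidates) tensors.
    # Purity sharpens the candidate distribution; minimising the loss
    # moves probability mass toward high-reward candidates.
    scores = F.softmax(policy_logits * purity, dim=1)
    return torch.mean(-torch.sum(scores * reward, dim=1))


# toy usage
logits = torch.randn(4, 8, requires_grad=True)
loss = policy_loss(logits, torch.rand(4, 8), torch.rand(4, 8))
loss.backward()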
Example #11
0
def train_one_epoch(model,
                    loss_fn,
                    metric,
                    dataloader,
                    optimizer,
                    max_grad_norm=0.0,
                    freezer=None,
                    log_period=-1):
    logger = logging.getLogger('shaper.train')
    meters = MetricLogger(delimiter='  ')
    metric.reset()
    meters.bind(metric)
    model.train()
    if freezer is not None:
        freezer.freeze()
    loss_fn.train()

    end = time.time()
    for iteration, data_batch in enumerate(dataloader):
        data_time = time.time() - end

        data_batch = {
            k: v.cuda(non_blocking=True)
            for k, v in data_batch.items()
        }

        preds = model(data_batch)

        optimizer.zero_grad()
        loss_dict = loss_fn(preds, data_batch)
        total_loss = sum(loss_dict.values())

        meters.update(loss=total_loss, **loss_dict)
        meters.update(node_acc=preds['node_acc'],
                      node_pos_acc=preds['node_pos_acc'],
                      node_neg_acc=preds['node_neg_acc'],
                      center_valid_ratio=preds['center_valid_ratio'])
        with torch.no_grad():
            # TODO add loss_dict hack
            metric.update_dict(preds, data_batch)

        total_loss.backward()
        if max_grad_norm > 0:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     max_norm=max_grad_norm)
        optimizer.step()

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        if log_period > 0 and iteration % log_period == 0:
            logger.info(
                meters.delimiter.join([
                    'iter: {iter:4d}',
                    '{meters}',
                    'lr: {lr:.2e}',
                    'max mem: {memory:.0f}',
                ]).format(
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer.param_groups[0]['lr'],
                    memory=torch.cuda.max_memory_allocated() / (1024.0**2),
                ))
    return meters
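
train_one_epoch follows the standard ordering: zero the gradients, backward, optionally clip the gradient norm, then step the optimizer. A stripped-down sketch of that ordering with placeholder model and data (nothing here comes from the repository):

import torch
import torch.nn as nn

model = nn.Linear(16, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
max_grad_norm = 1.0

x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
if max_grad_norm > 0:
    # clip after backward and before step, as in the loop above
    nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
optimizer.step()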
Example #12
0
def semantic_dist_init(cfg):
    logger = logging.getLogger("semantic_dist_init.trainer")

    _, backbone_name = cfg.MODEL.NAME.split('_')
    feature_num = 2048 if backbone_name.startswith('resnet') else 1024
    feat_estimator = semantic_dist_estimator(feature_num=feature_num, cfg=cfg)
    out_estimator = semantic_dist_estimator(feature_num=cfg.MODEL.NUM_CLASSES,
                                            cfg=cfg)

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    torch.cuda.empty_cache()

    # load checkpoint
    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        feature_extractor_weights = strip_prefix_if_present(
            checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(feature_extractor_weights)
        classifier_weights = strip_prefix_if_present(checkpoint['classifier'],
                                                     'module.')
        classifier.load_state_dict(classifier_weights)

    src_train_data = build_dataset(cfg,
                                   mode='train',
                                   is_source=True,
                                   epochwise=True)
    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=cfg.SOLVER.BATCH_SIZE_VAL,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False)
    iteration = 0
    feature_extractor.eval()
    classifier.eval()
    end = time.time()
    start_time = time.time()
    max_iters = len(src_train_loader)
    meters = MetricLogger(delimiter="  ")
    logger.info(
        ">>>>>>>>>>>>>>>> Initialize semantic distributions >>>>>>>>>>>>>>>>")
    logger.info(max_iters)
    with torch.no_grad():
        for i, (src_input, src_label, _) in enumerate(src_train_loader):
            data_time = time.time() - end

            src_input = src_input.cuda(non_blocking=True)
            src_label = src_label.cuda(non_blocking=True).long()

            src_feat = feature_extractor(src_input)
            src_out = classifier(src_feat)
            B, N, Hs, Ws = src_feat.size()
            _, C, _, _ = src_out.size()

            # source mask: downsample the ground-truth label
            src_mask = F.interpolate(src_label.unsqueeze(0).float(),
                                     size=(Hs, Ws),
                                     mode='nearest').squeeze(0).long()
            src_mask = src_mask.contiguous().view(B * Hs * Ws, )

            # feature level
            src_feat = src_feat.permute(0, 2, 3,
                                        1).contiguous().view(B * Hs * Ws, N)
            feat_estimator.update(features=src_feat.detach().clone(),
                                  labels=src_mask)

            # output level
            src_out = src_out.permute(0, 2, 3,
                                      1).contiguous().view(B * Hs * Ws, C)
            out_estimator.update(features=src_out.detach().clone(),
                                 labels=src_mask)

            batch_time = time.time() - end
            end = time.time()
            meters.update(time=batch_time, data=data_time)

            iteration = iteration + 1
            eta_seconds = meters.time.global_avg * (max_iters - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % 20 == 0 or iteration == max_iters:
                logger.info(
                    meters.delimiter.join([
                        "eta: {eta}", "iter: {iter}", "{meters}",
                        "max mem: {memory:.02f}"
                    ]).format(eta=eta_string,
                              iter=iteration,
                              meters=str(meters),
                              memory=torch.cuda.max_memory_allocated() /
                              1024.0 / 1024.0 / 1024.0))

            if iteration == max_iters:
                feat_estimator.save(name='feat_dist.pth')
                out_estimator.save(name='out_dist.pth')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=total_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_time / max_iters))
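
semantic_dist_estimator is not defined in this snippet; its update(features, labels) call suggests it accumulates per-class statistics of the (B*Hs*Ws, feature_num) features. The class below is only a guess at such an estimator (a running per-class mean), with a simplified constructor and save signature compared to the cfg-based one above:

import torch


class RunningClassMean:
    # Keeps a running mean feature vector per class; a hypothetical
    # stand-in for semantic_dist_estimator, not the repository's class.
    def __init__(self, feature_num, num_classes, ignore_index=255):
        self.mean = torch.zeros(num_classes, feature_num)
        self.count = torch.zeros(num_classes)
        self.ignore_index = ignore_index

    def update(self, features, labels):
        # features: (N, feature_num), labels: (N,)
        for c in labels.unique():
            c = int(c)
            if c == self.ignore_index:
                continue
            feats_c = features[labels == c]
            total = self.count[c] + feats_c.size(0)
            self.mean[c] = (self.mean[c] * self.count[c] +
                            feats_c.sum(dim=0)) / total
            self.count[c] = total

    def save(self, name):
        torch.save({'mean': self.mean, 'count': self.count}, name)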
Example #13
0
class Trainer(object):
    def __init__(self, args, logger):
        self.args = args
        self.logger = logger
        if get_rank() == 0:
            TBWriter.init(
                os.path.join(args.project_dir, args.task_dir, "tbevents")
            )
        self.device = torch.device(args.device)

        self.meters = MetricLogger(delimiter="  ")
        # image transform
        input_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                ),
            ]
        )
        # dataset and dataloader
        data_kwargs = {
            "transform": input_transform,
            "base_size": args.base_size,
            "crop_size": args.crop_size,
            "root": args.dataroot,
        }
        train_dataset = get_segmentation_dataset(
            args.dataset, split="train", mode="train", **data_kwargs
        )
        val_dataset = get_segmentation_dataset(
            args.dataset, split="val", mode="val", **data_kwargs
        )
        args.iters_per_epoch = len(train_dataset) // (
            args.num_gpus * args.batch_size
        )
        args.max_iters = args.epochs * args.iters_per_epoch

        train_sampler = make_data_sampler(
            train_dataset, shuffle=True, distributed=args.distributed
        )
        train_batch_sampler = make_batch_data_sampler(
            train_sampler, args.batch_size, args.max_iters
        )
        val_sampler = make_data_sampler(val_dataset, False, args.distributed)
        val_batch_sampler = make_batch_data_sampler(
            val_sampler, args.batch_size
        )

        self.train_loader = data.DataLoader(
            dataset=train_dataset,
            batch_sampler=train_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )
        self.val_loader = data.DataLoader(
            dataset=val_dataset,
            batch_sampler=val_batch_sampler,
            num_workers=args.workers,
            pin_memory=True,
        )

        # create network
        BatchNorm2d = nn.SyncBatchNorm if args.distributed else nn.BatchNorm2d
        self.model = get_segmentation_model(
            model=args.model,
            dataset=args.dataset,
            backbone=args.backbone,
            aux=args.aux,
            jpu=args.jpu,
            norm_layer=BatchNorm2d,
        ).to(self.device)

        # resume checkpoint if needed
        if args.resume:
            if os.path.isfile(args.resume):
                name, ext = os.path.splitext(args.resume)
                assert ext in (".pkl", ".pth"), \
                    "Sorry only .pth and .pkl files supported."
                print("Resuming training, loading {}...".format(args.resume))
                self.model.load_state_dict(
                    torch.load(
                        args.resume, map_location=lambda storage, loc: storage
                    )
                )

        # create criterion
        self.criterion = get_segmentation_loss(
            args.model,
            use_ohem=args.use_ohem,
            aux=args.aux,
            aux_weight=args.aux_weight,
            ignore_index=-1,
        ).to(self.device)

        # optimizer, for model just includes pretrained, head and auxlayer
        params_list = list()
        if hasattr(self.model, "pretrained"):
            params_list.append(
                {"params": self.model.pretrained.parameters(), "lr": args.lr}
            )
        if hasattr(self.model, "exclusive"):
            for module in self.model.exclusive:
                params_list.append(
                    {
                        "params": getattr(self.model, module).parameters(),
                        "lr": args.lr * args.lr_scale,
                    }
                )
        self.optimizer = torch.optim.SGD(
            params_list,
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

        # lr scheduling
        self.lr_scheduler = get_lr_scheduler(self.optimizer, args)
        if args.distributed:
            self.model = nn.parallel.DistributedDataParallel(
                self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
            )

        # evaluation metrics
        self.metric = SegmentationMetric(train_dataset.num_class)

        self.best_pred = 0.0

    def train(self):
        save_to_disk = get_rank() == 0
        epochs, max_iters = self.args.epochs, self.args.max_iters
        log_per_iters, val_per_iters = (
            self.args.log_iter,
            self.args.val_epoch * self.args.iters_per_epoch,
        )
        save_per_iters = self.args.save_epoch * self.args.iters_per_epoch
        start_time = time.time()
        self.logger.info(
            "Start training, Total Epochs: {:d} = Total Iterations {:d}".format(
                epochs, max_iters
            )
        )

        self.model.train()
        end = time.time()
        for iteration, (images, targets, _) in enumerate(self.train_loader):
            iteration = iteration + 1
            self.lr_scheduler.step()
            data_time = time.time() - end

            images = images.to(self.device)
            targets = targets.to(self.device)

            outputs = self.model(images)
            loss_dict = self.criterion(outputs, targets)

            losses = sum(loss for loss in loss_dict.values())

            # reduce losses over all GPUs for logging purposes
            loss_dict_reduced = reduce_loss_dict(loss_dict)
            losses_reduced = sum(loss for loss in loss_dict_reduced.values())

            self.optimizer.zero_grad()
            losses.backward()
            self.optimizer.step()
            batch_time = time.time() - end
            end = time.time()
            self.meters.update(
                data_time=data_time, batch_time=batch_time, loss=losses_reduced
            )

            eta_seconds = ((time.time() - start_time) / iteration) * (
                max_iters - iteration
            )
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

            if iteration % log_per_iters == 0 and save_to_disk:
                self.logger.info(
                    self.meters.delimiter.join(
                        [
                            "eta: {eta}",
                            "iter: {iter}",
                            "{meters}",
                            "lr: {lr:.6f}",
                            "max mem: {memory:.0f}",
                        ]
                    ).format(
                        eta=eta_string,
                        iter=iteration,
                        meters=str(self.meters),
                        lr=self.optimizer.param_groups[0]["lr"],
                        memory=torch.cuda.max_memory_allocated()
                        / 1024.0
                        / 1024.0,
                    )
                )
                if is_main_process():
                    # write train loss and lr
                    TBWriter.write_scalar(
                        ["train/loss", "train/lr", "train/mem"],
                        [
                            losses_reduced,
                            self.optimizer.param_groups[0]["lr"],
                            torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                        ],
                        iter=iteration,
                    )
                    # write time
                    TBWriter.write_scalars(
                        ["train/time"],
                        [self.meters.get_metric(["data_time", "batch_time"])],
                        iter=iteration,
                    )

            if iteration % save_per_iters == 0 and save_to_disk:
                save_checkpoint(self.model, self.args, is_best=False)

            if not self.args.skip_val and iteration % val_per_iters == 0:
                pixAcc, mIoU = self.validation()
                reduced_pixAcc = reduce_tensor(pixAcc)
                reduced_mIoU = reduce_tensor(mIoU)
                new_pred = (reduced_pixAcc + reduced_mIoU) / 2
                new_pred = float(new_pred.cpu().numpy())

                is_best = False
                if new_pred > self.best_pred:
                    is_best = True
                    self.best_pred = new_pred

                if is_main_process():
                    TBWriter.write_scalar(
                        ["val/PixelACC", "val/mIoU"],
                        [
                            reduced_pixAcc.cpu().numpy(),
                            reduced_mIoU.cpu().numpy(),
                        ],
                        iter=iteration,
                    )
                    save_checkpoint(self.model, self.args, is_best)
                synchronize()
                self.model.train()

        if is_main_process():
            save_checkpoint(self.model, self.args, is_best=False)
        total_training_time = time.time() - start_time
        total_training_str = str(
            datetime.timedelta(seconds=total_training_time)
        )
        self.logger.info(
            "Total training time: {} ({:.4f}s / it)".format(
                total_training_str, total_training_time / max_iters
            )
        )

    def validation(self):
        # total_inter, total_union, total_correct, total_label = 0, 0, 0, 0
        is_best = False
        self.metric.reset()
        if self.args.distributed:
            model = self.model.module
        else:
            model = self.model
        torch.cuda.empty_cache()  # TODO check if it helps
        model.eval()
        for i, (image, target, filename) in enumerate(self.val_loader):
            image = image.to(self.device)
            target = target.to(self.device)

            with torch.no_grad():
                outputs = model(image)
            self.metric.update(outputs[0], target)
            # pixAcc, mIoU = self.metric.get()
            # logger.info(
            # "Sample: {:d}, Validation pixAcc: {:.3f}, mIoU: {:.3f}".format(
            # i + 1, pixAcc, mIoU
            # )
            # )
        pixAcc, mIoU = self.metric.get()

        return (
            torch.tensor(pixAcc).to(self.device),
            torch.tensor(mIoU).to(self.device),
        )
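
reduce_tensor, used on pixAcc and mIoU in Trainer.train, is not shown. A typical implementation all-reduces the scalar across processes and averages it; the sketch below assumes that behaviour and falls back to a no-op when distributed training is not initialised:

import torch
import torch.distributed as dist


def reduce_tensor(tensor, average=True):
    # All-reduce a metric tensor across ranks (assumed helper, not the
    # repository's exact implementation).
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    reduced = tensor.clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    if average:
        reduced /= dist.get_world_size()
    return reduced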
Example #14
0
def test(cfg, output_dir=''):
    logger = logging.getLogger('shaper.test')

    # build model
    model, loss_fn, metric = build_model(cfg)
    model = nn.DataParallel(model).cuda()
    # model = model.cuda()

    # build checkpointer
    checkpointer = Checkpointer(model, save_dir=output_dir, logger=logger)

    if cfg.TEST.WEIGHT:
        # load weight if specified
        weight_path = cfg.TEST.WEIGHT.replace('@', output_dir)
        checkpointer.load(weight_path, resume=False)
    else:
        # load last checkpoint
        checkpointer.load(None, resume=True)

    # build data loader
    test_dataloader = build_dataloader(cfg, mode='test')
    test_dataset = test_dataloader.dataset

    # ---------------------------------------------------------------------------- #
    # Test
    # ---------------------------------------------------------------------------- #
    model.eval()
    loss_fn.eval()
    metric.eval()
    set_random_seed(cfg.RNG_SEED)
    evaluator = Evaluator(test_dataset.class_names)

    if cfg.TEST.VOTE.NUM_VOTE > 1:
        # remove old transform
        test_dataset.transform = None
        if cfg.TEST.VOTE.TYPE == 'AUGMENTATION':
            tmp_cfg = cfg.clone()
            tmp_cfg.defrost()
            tmp_cfg.TEST.AUGMENTATION = tmp_cfg.TEST.VOTE.AUGMENTATION
            transform = T.Compose([T.ToTensor()] +
                                  parse_augmentations(tmp_cfg, False) +
                                  [T.Transpose()])
            transform_list = [transform] * cfg.TEST.VOTE.NUM_VOTE
        elif cfg.TEST.VOTE.TYPE == 'MULTI_VIEW':
            # build new transform
            transform_list = []
            for view_ind in range(cfg.TEST.VOTE.NUM_VOTE):
                aug_type = T.RotateByAngleWithNormal if cfg.INPUT.USE_NORMAL else T.RotateByAngle
                rotate_by_angle = aug_type(
                    cfg.TEST.VOTE.MULTI_VIEW.AXIS,
                    2 * np.pi * view_ind / cfg.TEST.VOTE.NUM_VOTE)
                t = [T.ToTensor(), rotate_by_angle, T.Transpose()]
                if cfg.TEST.VOTE.MULTI_VIEW.SHUFFLE:
                    # Some non-deterministic algorithms, like PointNet++, benefit from shuffle.
                    t.insert(-1, T.Shuffle())
                transform_list.append(T.Compose(t))
        else:
            raise NotImplementedError('Unsupported voting method.')

        with torch.no_grad():
            tmp_dataloader = DataLoader(test_dataset,
                                        num_workers=1,
                                        collate_fn=lambda x: x[0])
            start_time = time.time()
            end = start_time
            for ind, data in enumerate(tmp_dataloader):
                data_time = time.time() - end
                points = data['points']

                # convert points into tensor
                points_batch = [t(points.copy()) for t in transform_list]
                points_batch = torch.stack(points_batch, dim=0)
                points_batch = points_batch.cuda(non_blocking=True)

                preds = model({'points': points_batch})
                # (batch_size, num_classes)
                cls_logit_batch = preds['cls_logit'].cpu().numpy()
                cls_logit_ensemble = np.mean(cls_logit_batch, axis=0)
                pred_label = np.argmax(cls_logit_ensemble)
                evaluator.update(pred_label, data['cls_label'])

                batch_time = time.time() - end
                end = time.time()

                if cfg.TEST.LOG_PERIOD > 0 and ind % cfg.TEST.LOG_PERIOD == 0:
                    logger.info('iter: {:4d}  time:{:.4f}  data:{:.4f}'.format(
                        ind, batch_time, data_time))
        test_time = time.time() - start_time
        logger.info('Test total time: {:.2f}s'.format(test_time))
    else:
        test_meters = MetricLogger(delimiter='  ')
        test_meters.bind(metric)
        with torch.no_grad():
            start_time = time.time()
            end = start_time
            for iteration, data_batch in enumerate(test_dataloader):
                data_time = time.time() - end

                cls_label_batch = data_batch['cls_label'].numpy()
                data_batch = {
                    k: v.cuda(non_blocking=True)
                    for k, v in data_batch.items()
                }

                preds = model(data_batch)

                loss_dict = loss_fn(preds, data_batch)
                total_loss = sum(loss_dict.values())

                test_meters.update(loss=total_loss, **loss_dict)
                metric.update_dict(preds, data_batch)

                # (batch_size, num_classes)
                cls_logit_batch = preds['cls_logit'].cpu().numpy()
                pred_label_batch = np.argmax(cls_logit_batch, axis=1)
                evaluator.batch_update(pred_label_batch, cls_label_batch)

                batch_time = time.time() - end
                end = time.time()
                test_meters.update(time=batch_time, data=data_time)

                if cfg.TEST.LOG_PERIOD > 0 and iteration % cfg.TEST.LOG_PERIOD == 0:
                    logger.info(
                        test_meters.delimiter.join([
                            'iter: {iter:4d}',
                            '{meters}',
                        ]).format(
                            iter=iteration,
                            meters=str(test_meters),
                        ))
        test_time = time.time() - start_time
        logger.info('Test {}  total time: {:.2f}s'.format(
            test_meters.summary_str, test_time))

    # evaluate
    logger.info('overall accuracy={:.2f}%'.format(100.0 *
                                                  evaluator.overall_accuracy))
    logger.info('average class accuracy={:.2f}%.\n{}'.format(
        100.0 * np.nanmean(evaluator.class_accuracy), evaluator.print_table()))
    evaluator.save_table(osp.join(output_dir, 'eval.cls.tsv'))
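
The multi-view branch above stacks several transformed copies of a single point cloud, averages the class logits over the copies, and takes the argmax. A toy version of that ensembling step (the classifier and the 40-class output size are placeholders, not the repository's model):

import numpy as np
import torch
import torch.nn as nn

# stand-in classifier over a flattened (3, 1024) cloud
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 1024, 40))
model.eval()

points = torch.randn(1024, 3)
num_votes = 4
views = []
for v in range(num_votes):
    angle = 2 * np.pi * v / num_votes
    c, s = float(np.cos(angle)), float(np.sin(angle))
    rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
    views.append((points @ rot.T).T)  # (3, 1024), mirroring T.Transpose()
points_batch = torch.stack(views)     # (num_votes, 3, 1024)

with torch.no_grad():
    cls_logit_batch = model(points_batch).numpy()  # (num_votes, num_classes)
cls_logit_ensemble = np.mean(cls_logit_batch, axis=0)
pred_label = int(np.argmax(cls_logit_ensemble))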
Example #15
0
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("FADA.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg, adv=True)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    model_D = build_adversarial_discriminator(cfg)
    model_D.to(device)

    if local_rank == 0:
        print(feature_extractor)
        print(model_D)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        model_D = torch.nn.parallel.DistributedDataParallel(
            model_D,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)
        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(classifier.parameters(),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    optimizer_D = torch.optim.Adam(model_D.parameters(),
                                   lr=cfg.SOLVER.BASE_LR_D,
                                   betas=(0.9, 0.99))
    optimizer_D.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    start_epoch = 0
    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        if "model_D" in checkpoint:
            logger.info("Loading model_D from {}".format(cfg.resume))
            model_D_weights = checkpoint[
                'model_D'] if distributed else strip_prefix_if_present(
                    checkpoint['model_D'], 'module.')
            model_D.load_state_dict(model_D_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "optimizer_D" in checkpoint:
        #     logger.info("Loading optimizer_D from {}".format(cfg.resume))
        #     optimizer_D.load_state_dict(checkpoint['optimizer_D'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    bce_loss = torch.nn.BCELoss(reduction='none')

    max_iters = cfg.SOLVER.MAX_ITER
    source_label = 0
    target_label = 1
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    model_D.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name),
            (tgt_input, _,
             _)) in enumerate(zip(src_train_loader, tgt_train_loader)):
        #             torch.distributed.barrier()
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        current_lr_D = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                            cfg.SOLVER.BASE_LR_D,
                                            iteration,
                                            max_iters,
                                            power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10
        for index in range(len(optimizer_D.param_groups)):
            optimizer_D.param_groups[index]['lr'] = current_lr_D


        # torch.distributed.barrier()

        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        optimizer_D.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]
        inp = torch.cat(
            [src_input, F.interpolate(tgt_input.detach(), src_size)])
        # try:
        src_fea = feature_extractor(inp)[:batch_size]

        src_pred = classifier(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        loss_seg = criterion(src_pred, src_label)
        loss_seg.backward()

        # torch.distributed.barrier()

        # generate soft labels
        src_soft_label = F.softmax(src_pred, dim=1).detach()
        src_soft_label[src_soft_label > 0.9] = 0.9

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred = tgt_pred.div(temperature)
        tgt_soft_label = F.softmax(tgt_pred, dim=1)

        tgt_soft_label = tgt_soft_label.detach()
        tgt_soft_label[tgt_soft_label > 0.9] = 0.9

        tgt_D_pred = model_D(tgt_fea, tgt_size)
        loss_adv_tgt = 0.001 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (tgt_soft_label, torch.zeros_like(tgt_soft_label)), dim=1))
        loss_adv_tgt.backward()

        optimizer_fea.step()
        optimizer_cls.step()

        optimizer_D.zero_grad()
        # torch.distributed.barrier()

        src_D_pred = model_D(src_fea.detach(), src_size)
        loss_D_src = 0.5 * soft_label_cross_entropy(
            src_D_pred,
            torch.cat(
                (src_soft_label, torch.zeros_like(src_soft_label)), dim=1))
        loss_D_src.backward()

        tgt_D_pred = model_D(tgt_fea.detach(), tgt_size)
        loss_D_tgt = 0.5 * soft_label_cross_entropy(
            tgt_D_pred,
            torch.cat(
                (torch.zeros_like(tgt_soft_label), tgt_soft_label), dim=1))
        loss_D_tgt.backward()

        # torch.distributed.barrier()

        optimizer_D.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_adv_tgt=loss_adv_tgt.item())
        meters.update(loss_D=(loss_D_src.item() + loss_D_tgt.item()))
        meters.update(loss_D_src=loss_D_src.item())
        meters.update(loss_D_tgt=loss_D_tgt.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'model_D': model_D.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict(),
                    'optimizer_D': optimizer_D.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / (cfg.SOLVER.MAX_ITER)))

    return feature_extractor, classifier
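
soft_label_cross_entropy is not included in this snippet. The adversarial losses above pass discriminator logits together with a 2C-channel soft target, which points to a cross-entropy against a soft distribution; the function below assumes that definition (class dimension 1, no pixel weighting):

import torch.nn.functional as F


def soft_label_cross_entropy(pred, soft_label):
    # pred:       (B, 2C, H, W) discriminator logits
    # soft_label: (B, 2C, H, W) soft targets
    log_prob = F.log_softmax(pred, dim=1)
    return -(soft_label * log_prob).sum(dim=1).mean()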
Example #16
0
def train(cfg, local_rank, distributed):
    logger = logging.getLogger("BCDM.trainer")
    logger.info("Start training")

    feature_extractor = build_feature_extractor(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    feature_extractor.to(device)

    classifier = build_classifier(cfg)
    classifier.to(device)

    classifier_2 = build_classifier(cfg)
    classifier_2.to(device)

    if local_rank == 0:
        print(feature_extractor)

    batch_size = cfg.SOLVER.BATCH_SIZE // 2
    if distributed:
        pg1 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))

        batch_size = int(
            cfg.SOLVER.BATCH_SIZE / torch.distributed.get_world_size()) // 2
        if not cfg.MODEL.FREEZE_BN:
            feature_extractor = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                feature_extractor)
        feature_extractor = torch.nn.parallel.DistributedDataParallel(
            feature_extractor,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg1)
        pg2 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier = torch.nn.parallel.DistributedDataParallel(
            classifier,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg2)
        pg3 = torch.distributed.new_group(
            range(torch.distributed.get_world_size()))
        classifier_2 = torch.nn.parallel.DistributedDataParallel(
            classifier_2,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True,
            process_group=pg3)

        torch.autograd.set_detect_anomaly(True)
        torch.distributed.barrier()

    optimizer_fea = torch.optim.SGD(feature_extractor.parameters(),
                                    lr=cfg.SOLVER.BASE_LR,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_fea.zero_grad()

    optimizer_cls = torch.optim.SGD(list(classifier.parameters()) +
                                    list(classifier_2.parameters()),
                                    lr=cfg.SOLVER.BASE_LR * 10,
                                    momentum=cfg.SOLVER.MOMENTUM,
                                    weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    optimizer_cls.zero_grad()

    output_dir = cfg.OUTPUT_DIR

    save_to_disk = local_rank == 0

    iteration = 0

    if cfg.resume:
        logger.info("Loading checkpoint from {}".format(cfg.resume))
        checkpoint = torch.load(cfg.resume, map_location=torch.device('cpu'))
        model_weights = checkpoint[
            'feature_extractor'] if distributed else strip_prefix_if_present(
                checkpoint['feature_extractor'], 'module.')
        feature_extractor.load_state_dict(model_weights)
        classifier_weights = checkpoint[
            'classifier'] if distributed else strip_prefix_if_present(
                checkpoint['classifier'], 'module.')
        classifier.load_state_dict(classifier_weights)
        classifier_2_weights = checkpoint[
            'classifier_2'] if distributed else strip_prefix_if_present(
                checkpoint['classifier_2'], 'module.')
        classifier_2.load_state_dict(classifier_2_weights)
        # if "optimizer_fea" in checkpoint:
        #     logger.info("Loading optimizer_fea from {}".format(cfg.resume))
        #     optimizer_fea.load_state_dict(checkpoint['optimizer_fea'])
        # if "optimizer_cls" in checkpoint:
        #     logger.info("Loading optimizer_cls from {}".format(cfg.resume))
        #     optimizer_cls.load_state_dict(checkpoint['optimizer_cls'])
        # if "iteration" in checkpoint:
        #     iteration = checkpoint['iteration']

    src_train_data = build_dataset(cfg, mode='train', is_source=True)
    tgt_train_data = build_dataset(cfg, mode='train', is_source=False)

    if distributed:
        src_train_sampler = torch.utils.data.distributed.DistributedSampler(
            src_train_data)
        tgt_train_sampler = torch.utils.data.distributed.DistributedSampler(
            tgt_train_data)
    else:
        src_train_sampler = None
        tgt_train_sampler = None

    src_train_loader = torch.utils.data.DataLoader(
        src_train_data,
        batch_size=batch_size,
        shuffle=(src_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=src_train_sampler,
        drop_last=True)
    tgt_train_loader = torch.utils.data.DataLoader(
        tgt_train_data,
        batch_size=batch_size,
        shuffle=(tgt_train_sampler is None),
        num_workers=4,
        pin_memory=True,
        sampler=tgt_train_sampler,
        drop_last=True)

    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)

    max_iters = cfg.SOLVER.MAX_ITER
    logger.info("Start training")
    meters = MetricLogger(delimiter="  ")
    feature_extractor.train()
    classifier.train()
    classifier_2.train()
    start_training_time = time.time()
    end = time.time()
    for i, ((src_input, src_label, src_name),
            (tgt_input, _,
             _)) in enumerate(zip(src_train_loader, tgt_train_loader)):
        data_time = time.time() - end

        current_lr = adjust_learning_rate(cfg.SOLVER.LR_METHOD,
                                          cfg.SOLVER.BASE_LR,
                                          iteration,
                                          max_iters,
                                          power=cfg.SOLVER.LR_POWER)
        for index in range(len(optimizer_fea.param_groups)):
            optimizer_fea.param_groups[index]['lr'] = current_lr
        for index in range(len(optimizer_cls.param_groups)):
            optimizer_cls.param_groups[index]['lr'] = current_lr * 10

        # Step A: train on source (CE loss) and target (Ent loss)
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_input = src_input.cuda(non_blocking=True)
        src_label = src_label.cuda(non_blocking=True).long()
        tgt_input = tgt_input.cuda(non_blocking=True)

        src_size = src_input.shape[-2:]
        tgt_size = tgt_input.shape[-2:]

        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        # source segmentation loss
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        total_loss = loss_seg + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        # torch.distributed.barrier()
        optimizer_fea.step()
        optimizer_cls.step()

        # Step B: train bi-classifier to maximize loss_cdd
        optimizer_fea.zero_grad()
        optimizer_cls.zero_grad()
        src_fea = feature_extractor(src_input)
        src_pred = classifier(src_fea, src_size)
        src_pred_2 = classifier_2(src_fea, src_size)
        temperature = 1.8
        src_pred = src_pred.div(temperature)
        src_pred_2 = src_pred_2.div(temperature)
        loss_seg = criterion(src_pred, src_label) + criterion(
            src_pred_2, src_label)

        tgt_fea = feature_extractor(tgt_input)
        tgt_pred = classifier(tgt_fea, tgt_size)
        tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
        tgt_pred = F.softmax(tgt_pred, dim=1)
        tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
        loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
        loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
        total_loss = loss_seg - cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
        total_loss.backward()
        optimizer_cls.step()

        # Step C: train feature extractor to min loss_cdd
        for k in range(cfg.SOLVER.NUM_K):
            optimizer_fea.zero_grad()
            optimizer_cls.zero_grad()
            tgt_fea = feature_extractor(tgt_input)
            tgt_pred = classifier(tgt_fea, tgt_size)
            tgt_pred_2 = classifier_2(tgt_fea, tgt_size)
            tgt_pred = F.softmax(tgt_pred, dim=1)
            tgt_pred_2 = F.softmax(tgt_pred_2, dim=1)
            loss_ent = entropy_loss(tgt_pred) + entropy_loss(tgt_pred_2)
            loss_cdd = discrepancy_calc(tgt_pred, tgt_pred_2)
            total_loss = cfg.SOLVER.CDD_LOSS * loss_cdd + cfg.SOLVER.ENT_LOSS * loss_ent
            total_loss.backward()
            # step the feature extractor so Step C actually minimizes loss_cdd
            optimizer_fea.step()

        meters.update(loss_seg=loss_seg.item())
        meters.update(loss_cdd=loss_cdd.item())
        meters.update(loss_ent=loss_ent.item())

        iteration = iteration + 1

        n = src_input.size(0)

        batch_time = time.time() - end
        end = time.time()
        meters.update(time=batch_time, data=data_time)

        eta_seconds = meters.time.global_avg * (cfg.SOLVER.STOP_ITER -
                                                iteration)
        eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))

        if iteration % 20 == 0 or iteration == max_iters:
            logger.info(
                meters.delimiter.join([
                    "eta: {eta}",
                    "iter: {iter}",
                    "{meters}",
                    "lr: {lr:.6f}",
                    "max mem: {memory:.0f}",
                ]).format(
                    eta=eta_string,
                    iter=iteration,
                    meters=str(meters),
                    lr=optimizer_fea.param_groups[0]["lr"],
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
                ))

        if (iteration == cfg.SOLVER.MAX_ITER or iteration %
                cfg.SOLVER.CHECKPOINT_PERIOD == 0) and save_to_disk:
            filename = os.path.join(output_dir,
                                    "model_iter{:06d}.pth".format(iteration))
            torch.save(
                {
                    'iteration': iteration,
                    'feature_extractor': feature_extractor.state_dict(),
                    'classifier': classifier.state_dict(),
                    'classifier_2': classifier_2.state_dict(),
                    'optimizer_fea': optimizer_fea.state_dict(),
                    'optimizer_cls': optimizer_cls.state_dict()
                }, filename)

        if iteration == cfg.SOLVER.MAX_ITER:
            break
        if iteration == cfg.SOLVER.STOP_ITER:
            break

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info("Total training time: {} ({:.4f} s / it)".format(
        total_time_str, total_training_time / cfg.SOLVER.MAX_ITER))

    return feature_extractor, classifier, classifier_2
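
entropy_loss and discrepancy_calc are likewise not shown. The calls above pass softmax outputs, so a common MCD/BCDM-style choice is the mean prediction entropy and the mean absolute difference between the two classifiers; the sketches below assume those formulations rather than the repository's exact ones:

import torch


def entropy_loss(prob):
    # Mean per-pixel entropy of a softmax output of shape (B, C, H, W).
    return -(prob * torch.log(prob + 1e-8)).sum(dim=1).mean()


def discrepancy_calc(prob1, prob2):
    # Mean absolute difference between two softmax outputs
    # (classifier discrepancy).
    return (prob1 - prob2).abs().mean()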