Esempio n. 1
0
def test_dataloader():
    """Visual smoke test of the first registered dataset.

    Iterates the dataloader built for ``dataset_names()[0]`` and draws the
    boxes of the first sample of every batch with ``plot_bbox``.
    """
    first_name = dataset_names()[0]
    loader = build_dataloader(first_name)
    for batch_imgs, batch_boxes in loader:
        # Move the leading (presumably channel) axis last for plotting.
        frame = batch_imgs[0].cpu().numpy().transpose(1, 2, 0)
        plot_bbox(frame, batch_boxes[0])
Esempio n. 2
0
def test(epochs_tested):
    """Evaluate FCOS checkpoints on the COCO evaluation split.

    For every entry of ``epochs_tested`` the corresponding weights are loaded
    with ``utils.load_model``, the model is run over the evaluation
    dataloader, and the collected detections are passed to
    ``utils.evaluate_coco`` together with the output path
    ``<cfg.output_path>/fcos_<epoch>.json``.

    Args:
        epochs_tested: list, tuple or set of epoch identifiers to evaluate.

    Raises:
        TypeError: if ``epochs_tested`` is not a list, tuple or set.
    """
    # Validate up front (and without ``assert``, which ``python -O`` strips)
    # so a bad call fails before the dataset and model are built.
    if not isinstance(epochs_tested, (list, tuple, set)):
        raise TypeError(
            "during test, epochs_tested must be a list, tuple or set!")

    is_train = False
    transforms = transform.build_transforms(is_train=is_train)
    coco_dataset = dataset.COCODataset(is_train=is_train, transforms=transforms)
    dataloader = build_dataloader(coco_dataset, sampler=None, is_train=is_train)
    eval_dataset = dataloader.dataset

    model = FCOS(is_train=is_train)

    for epoch in epochs_tested:
        utils.load_model(model, epoch)
        model.cuda()
        model.eval()

        final_results = []

        with torch.no_grad():
            for data in tqdm(dataloader):
                img = data["images"].cuda()
                ori_img_shape = data["ori_img_shape"].cuda()
                fin_img_shape = data["fin_img_shape"].cuda()

                cls_pred, reg_pred, label_pred = model(
                    [img, ori_img_shape, fin_img_shape])

                # Only element 0 of each batch is consumed below, so this
                # effectively assumes an evaluation batch size of 1.
                index = data["indexs"][0]
                imgid = eval_dataset.img_infos[index]["id"]

                # Convert each prediction tensor to Python lists once,
                # instead of calling .tolist() per detection.
                boxes = utils.xyxy2xywh(reg_pred[0].cpu()).tolist()
                scores = cls_pred[0].cpu().tolist()
                labels = label_pred[0].cpu().tolist()

                final_results.extend(
                    {
                        "image_id": imgid,
                        "category_id": eval_dataset.label2catid[label],
                        "bbox": box,
                        "score": score,
                    }
                    for box, score, label in zip(boxes, scores, labels)
                )

        output_path = os.path.join(cfg.output_path, "fcos_" + str(epoch) + ".json")
        utils.evaluate_coco(eval_dataset.coco, final_results, output_path, "bbox")
Esempio n. 3
0
    def __init__(self, opt):
        """Build the data loaders, network, loss and optimizer from ``opt``.

        Args:
            opt: options object. This constructor reads ``device``,
                ``channels``, ``height``, ``width``, ``pretrained_model`` and
                ``learning_rate``; ``build_dataloader`` and ``network`` may
                read more.
        """
        self.opt = opt
        self.device = self.opt.device

        train_loader, valid_loader = build_dataloader(opt)
        self.dataloader = {'train': train_loader, 'valid': valid_loader}

        net = network(self.opt, self.opt.channels, self.opt.height,
                      self.opt.width)
        self.net = net
        net.to(self.device)

        self.criterion = nn.CrossEntropyLoss()

        # Weights (if requested) are loaded before the optimizer is created,
        # which then takes its parameters from whatever self.net is now.
        if self.opt.pretrained_model:
            self.load_weight()

        self.optimizer = torch.optim.Adam(
            params=self.net.parameters(),
            lr=self.opt.learning_rate,
            weight_decay=1e-4,
        )
Esempio n. 4
0
def main(name):
    """Run the SORT tracker over sequence ``name`` and save the tracks.

    Each row written to ``./results/<name>.txt`` (comma-separated, one
    decimal) is ``frame_id, track_id, x, y, w, h`` followed by four
    constant ``1`` columns.

    Args:
        name: sequence name, forwarded to ``build_dataloader`` and used as the
            output file stem.
    """
    dataloader = build_dataloader(name)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = SORT().to(device)
    model.eval()

    outputs = []
    # Pure inference: no autograd graph needs to be recorded.
    with torch.no_grad():
        for x, y in tqdm(dataloader):
            x = x.to(device)
            tracker_outs = model(x)

            # Reformat bboxes from xyxy to xywh. Clone first: the slice is a
            # view, and the in-place subtraction would otherwise write back
            # into the tensor returned by the tracker.
            bboxes = tracker_outs[:, :4].clone()
            bboxes[:, 2] -= bboxes[:, 0]
            bboxes[:, 3] -= bboxes[:, 1]

            # Generate the final output format.
            tracker_ids = tracker_outs[:, [4]]
            frame_id = y[0, 0, -1] * torch.ones_like(tracker_ids)
            dummy = torch.ones_like(bboxes)
            outputs.append(
                torch.cat([frame_id, tracker_ids, bboxes, dummy], dim=1))

    if outputs:
        results = torch.cat(outputs, dim=0).cpu().numpy()
    else:
        # No frames at all: write an empty file instead of failing in
        # torch.cat. 10 columns = frame id + track id + 4 box + 4 dummy.
        results = np.empty((0, 10))

    out_dir = './results'
    os.makedirs(out_dir, exist_ok=True)
    np.savetxt(
        os.path.join(out_dir, '%s.txt' % name),
        results,
        fmt='%.1f',
        delimiter=',',
    )
Esempio n. 5
0
def train(num_epochs=80):
    """Train the RML model on labeled plus dummy-labeled unlabeled data.

    Every 20 epochs (epochs 19, 39, ...) the unlabeled loader is rebuilt from
    the current model. Every 10 epochs (epochs 9, 19, ...) AUC/PRC on the
    validation loader are printed and a checkpoint
    ``rml_<LABELED_SIZE>_<epoch>.pth`` is saved. The final weights go to
    ``rml_<LABELED_SIZE>_final.pth``.

    Args:
        num_epochs: number of training epochs; defaults to the previously
            hard-coded 80.
    """
    model = RML()
    model.cuda()

    # build_dataloader receives the model for the "train" split, presumably
    # to assign the dummy labels of the unlabeled part; use eval mode there.
    model.eval()
    tr_labeled, tr_unlabeled = build_dataloader("train", model)
    val_loader, _ = build_dataloader("val", None)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.SOLVER.BASE_LR,
                                momentum=0.9,
                                weight_decay=cfg.SOLVER.WEIGHT_DECAY)
    criterion = RobustWarpLoss(cfg.DATA.BATCH_SIZE, cfg.DATA.BATCH_SIZE)

    t = tqdm(range(num_epochs))
    for epoch in t:
        if epoch % 20 == 19:  # reinitialize dummy labels
            model.eval()
            _, tr_unlabeled = build_dataloader("train", model)
        model.train()
        # zip() stops at the shorter of the two loaders.
        for labeled, unlabeled in zip(tr_labeled, tr_unlabeled):
            labeled_ims, labeled_targets = labeled
            unlabeled_ims, unlabeled_targets, s = unlabeled

            labeled_ims = labeled_ims.cuda()
            labeled_targets = labeled_targets.cuda()
            unlabeled_ims = unlabeled_ims.cuda()
            unlabeled_targets = unlabeled_targets.cuda()

            optimizer.zero_grad()
            labeled_outputs = model(labeled_ims)
            unlabeled_outputs = model(unlabeled_ims)

            loss = criterion(torch.cat((labeled_outputs, unlabeled_outputs)),
                             torch.cat((labeled_targets, unlabeled_targets)),
                             s)
            loss.backward()
            optimizer.step()

            t.set_postfix({"Loss": loss.item()})

        if epoch % 10 == 9:
            model.eval()
            auc_ = []
            prc_ = []
            for ims, labels in val_loader:
                ims, labels = ims.cuda(), labels.cuda()
                with torch.no_grad():
                    outputs = model(ims)
                auc_.append(get_auc(labels, outputs))
                prc_.append(get_prc(labels, outputs))
            print("EVAL")
            print(
                f"- [{epoch}]: AUC: {np.nanmean(auc_):0.4f} | PRC: {np.nanmean(prc_):0.4f}"
            )
            torch.save(model.state_dict(),
                       f"rml_{cfg.DATA.LABELED_SIZE}_{epoch:02d}.pth")

    t.close()

    torch.save(model.state_dict(), f"rml_{cfg.DATA.LABELED_SIZE}_final.pth")
Esempio n. 6
0
def PolarOffsetMain(args, cfg):
    """Train, validate or test a PolarOffset panoptic/tracking model.

    The mode is chosen from ``args``:

    * ``args.onlytest``: run inference on the test split, save predictions
      under ``<output_dir>/test/sequences`` and return.
    * ``args.onlyval``: evaluate on the val split (saving predictions when
      ``args.saveval`` is set) and return.
    * otherwise: train epoch by epoch (up to ``cfg.OPTIMIZE.MAX_EPOCH`` if
      that key exists, else forever). After every epoch, evaluate on the val
      split and, on rank 0, write at most one checkpoint when any tracked
      metric improved.

    Args:
        args: parsed command-line arguments (launcher, batch_size, tag,
            config, ckpt_name, pretrained_ckpt, nofix, fix_semantic_instance,
            onlytest, onlyval, saveval, ...). ``args.batch_size`` is
            overwritten when a distributed launcher is used.
        cfg: experiment config. ``cfg['DIST_TRAIN']`` is written here, and
            ``cfg.LOCAL_RANK`` too when a distributed launcher is used.
    """
    ### distributed setup
    if args.launcher is None:
        dist_train = False
    else:
        # Dispatch to common_utils.init_dist_<launcher>; its results overwrite
        # args.batch_size and cfg.LOCAL_RANK.
        args.batch_size, cfg.LOCAL_RANK = getattr(
            common_utils, 'init_dist_%s' % args.launcher)(args.batch_size,
                                                          args.tcp_port,
                                                          args.local_rank,
                                                          backend='nccl')
        dist_train = True
    cfg['DIST_TRAIN'] = dist_train

    # Both tracking variants share the tracking forward signature and loss.
    is_tracking_model = cfg.MODEL.NAME.startswith(
        ('PolarOffsetSpconvPytorchMeanshiftTracking',
         'PolarOffsetSpconvTracking'))

    ### output directories
    output_dir = os.path.join('./output', args.tag)
    ckpt_dir = os.path.join(output_dir, 'ckpt')
    tmp_dir = os.path.join(output_dir, 'tmp')
    summary_dir = os.path.join(output_dir, 'summary')
    for directory in (output_dir, ckpt_dir, tmp_dir, summary_dir):
        os.makedirs(directory, exist_ok=True)

    # Per-sequence prediction folders: sequence 08 when saving val
    # predictions, sequences 11-21 when testing.
    results_dir = os.path.join(output_dir, 'test', 'sequences')
    result_sequences = []
    if args.onlyval and args.saveval:
        result_sequences.extend(range(8, 9))
    if args.onlytest:
        result_sequences.extend(range(11, 22))
    for seq in result_sequences:
        os.makedirs(os.path.join(results_dir, str(seq).zfill(2), 'predictions'),
                    exist_ok=True)

    ### logging
    log_file = os.path.join(
        output_dir, ('log_train_%s.txt' %
                     datetime.datetime.now().strftime('%Y%m%d-%H%M%S')))
    logger = common_utils.create_logger(log_file, rank=cfg.LOCAL_RANK)

    logger.info('**********************Start logging**********************')
    gpu_list = os.environ.get('CUDA_VISIBLE_DEVICES', 'ALL')
    logger.info('CUDA_VISIBLE_DEVICES=%s' % gpu_list)

    if dist_train:
        total_gpus = dist.get_world_size()
        logger.info('total_batch_size: %d' % (total_gpus * args.batch_size))
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    log_config_to_file(cfg, logger=logger)
    if cfg.LOCAL_RANK == 0:
        # Keep a copy of the config next to the outputs. shutil avoids going
        # through a shell (paths with spaces/metacharacters); failure stays
        # non-fatal as with the previous `cp` call, but is now logged.
        import shutil
        try:
            shutil.copy(args.config, output_dir)
        except OSError as err:
            logger.info('Could not copy config {}: {}'.format(args.config, err))

    def short_fname(inputs):
        """Return the last three path components of the batch's first file."""
        return '/'.join(inputs['pcd_fname'][0].split('/')[-3:])

    ### create dataloader
    eval_loader_kwargs = dict(logger=logger, no_shuffle=True, no_aug=True)
    if args.onlytest:
        test_dataset_loader = build_dataloader(args, cfg, split='test',
                                               **eval_loader_kwargs)
    else:
        if not args.onlyval:
            train_dataset_loader = build_dataloader(args, cfg, split='train',
                                                    logger=logger)
        val_dataset_loader = build_dataloader(args, cfg, split='val',
                                              **eval_loader_kwargs)

    ### create model
    model = build_network(cfg)
    model.cuda()

    ### create optimizer
    optimizer = train_utils.build_optimizer(model, cfg)

    def freeze_pretrained_parts():
        """Freeze pretrained sub-networks of ``model`` unless ``args.nofix``."""
        if args.nofix:
            logger.info("No Freeze.")
        elif args.fix_semantic_instance:
            logger.info(
                "Freezing backbone, semantic and instance part of the model.")
            model.fix_semantic_instance_parameters()
        else:
            logger.info("Freezing semantic and backbone part of the model.")
            model.fix_semantic_parameters()

    ### load ckpt
    ckpt_fname = os.path.join(ckpt_dir, args.ckpt_name)
    epoch = -1

    other_state = {}
    if args.pretrained_ckpt is not None and os.path.exists(ckpt_fname):
        # Resuming a fine-tuning run: freeze first, rebuild the optimizer
        # (presumably so it only holds the still-trainable parameters), then
        # restore model/optimizer/other state from the checkpoint.
        logger.info(
            "Now in pretrain mode and loading ckpt: {}".format(ckpt_fname))
        freeze_pretrained_parts()
        optimizer = train_utils.build_optimizer(model, cfg)
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(
            model,
            ckpt_fname,
            to_cpu=dist_train,
            optimizer=optimizer,
            logger=logger)
        logger.info("Loaded Epoch: {}".format(epoch))
    elif args.pretrained_ckpt is not None:
        # Starting a fine-tuning run from pretrained weights.
        train_utils.load_pretrained_model(model,
                                          args.pretrained_ckpt,
                                          to_cpu=dist_train,
                                          logger=logger)
        freeze_pretrained_parts()
        optimizer = train_utils.build_optimizer(model, cfg)
    elif os.path.exists(ckpt_fname):
        # Resuming a regular run.
        epoch, other_state = train_utils.load_params_with_optimizer_otherstate(
            model,
            ckpt_fname,
            to_cpu=dist_train,
            optimizer=optimizer,
            logger=logger)
        logger.info("Loaded Epoch: {}".format(epoch))
    if other_state is None:
        other_state = {}

    ### create scheduler (none is used at the moment)
    lr_scheduler = None
    if lr_scheduler is None:
        logger.info('Not using lr scheduler')

    # Switch to train mode before wrapping in DistributedDataParallel so that
    # the fixed (frozen) parameters are supported.
    model.train()
    if dist_train:
        model = nn.parallel.DistributedDataParallel(
            model,
            device_ids=[cfg.LOCAL_RANK % torch.cuda.device_count()],
            find_unused_parameters=True)
    logger.info(model)

    if cfg.LOCAL_RANK == 0:
        writer = SummaryWriter(log_dir=summary_dir)

    logger.info('**********************Start Training**********************')
    rank = cfg.LOCAL_RANK

    def restored(key, default):
        """Return ``other_state[key]`` if present, else ``default``."""
        return other_state[key] if key in other_state else default

    best_before_iou = restored('best_before_iou', -1)
    best_pq = restored('best_pq', -1)
    best_after_iou = restored('best_after_iou', -1)
    global_iter = restored('global_iter', 0)
    val_global_iter = restored('val_global_iter', 0)
    best_tracking_loss = restored('best_tracking_loss', 10086)

    ### test
    if args.onlytest:
        logger.info('----EPOCH {} Testing----'.format(epoch))
        model.eval()
        if rank == 0:
            vbar = tqdm(total=len(test_dataset_loader), dynamic_ncols=True)
        for inputs in test_dataset_loader:
            with torch.no_grad():
                if is_tracking_model:
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=None,
                                     merge_evaluator_window_k_list=None,
                                     require_cluster=True)
                else:
                    ret_dict = model(inputs,
                                     is_test=True,
                                     require_cluster=True,
                                     require_merge=True)
                common_utils.save_test_results(ret_dict, results_dir, inputs)
            if rank == 0:
                vbar.set_postfix({'fname': short_fname(inputs)})
                vbar.update(1)
        if rank == 0:
            vbar.close()
        logger.info("----Testing Finished----")
        return

    ### evaluate
    if args.onlyval:
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50  # according to SemanticKITTI official rule
        if is_tracking_model:
            # One evaluator per merge window size k.
            merge_evaluator_window_k_list = [1, 5, 10, 15]
            merge_evaluator_list = [
                init_eval(min_points) for _ in merge_evaluator_window_k_list
            ]
        else:
            before_merge_evaluator = init_eval(min_points=min_points)
            after_merge_evaluator = init_eval(min_points=min_points)
        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            inputs['i_iter'] = i_iter
            with torch.no_grad():
                if is_tracking_model:
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=merge_evaluator_list,
                                     merge_evaluator_window_k_list=
                                     merge_evaluator_window_k_list,
                                     require_cluster=True)
                else:
                    ret_dict = model(
                        inputs,
                        is_test=True,
                        before_merge_evaluator=before_merge_evaluator,
                        after_merge_evaluator=after_merge_evaluator,
                        require_cluster=True)
                if args.saveval:
                    common_utils.save_test_results(ret_dict, results_dir,
                                                   inputs)
            if rank == 0:
                vbar.set_postfix({
                    'loss': ret_dict['loss'].item(),
                    'fname': short_fname(inputs),
                    'ins_num': ret_dict.get('ins_num', -1)
                    if isinstance(ret_dict, dict) else
                    (ret_dict['ins_num'] if 'ins_num' in ret_dict else -1),
                })
                vbar.update(1)
        if dist_train and not is_tracking_model:
            # Combine the per-rank evaluators (via files in tmp_dir).
            before_merge_evaluator = common_utils.merge_evaluator(
                before_merge_evaluator, tmp_dir)
            dist.barrier()
            after_merge_evaluator = common_utils.merge_evaluator(
                after_merge_evaluator, tmp_dir)

        if rank == 0:
            vbar.close()
            ## print results
            if is_tracking_model:
                for evaluate, window_k in zip(merge_evaluator_list,
                                              merge_evaluator_window_k_list):
                    logger.info("Current Window K: {}".format(window_k))
                    printResults(evaluate, logger=logger)
            else:
                logger.info("Before Merge Semantic Scores")
                printResults(before_merge_evaluator,
                             logger=logger,
                             sem_only=True)
                logger.info("After Merge Panoptic Scores")
                printResults(after_merge_evaluator, logger=logger)

        logger.info("----Evaluating Finished----")
        return

    ### train
    # Anomaly detection is a global autograd switch; enabling it once here is
    # equivalent to re-enabling it on every iteration.
    torch.autograd.set_detect_anomaly(True)
    while True:
        epoch += 1
        if 'MAX_EPOCH' in cfg.OPTIMIZE.keys():
            if epoch > cfg.OPTIMIZE.MAX_EPOCH:
                break

        ### train one epoch
        logger.info('----EPOCH {} Training----'.format(epoch))
        loss_acc = 0
        if rank == 0:
            pbar = tqdm(total=len(train_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(train_dataset_loader):
            model.train()
            optimizer.zero_grad()
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            ret_dict = model(inputs)

            if args.pretrained_ckpt is not None and not args.fix_semantic_instance:
                # Fine-tuning the offset branch.
                if args.nofix:
                    loss = ret_dict['loss']
                elif len(ret_dict['offset_loss_list']) > 0:
                    loss = sum(ret_dict['offset_loss_list'])
                else:
                    # Nothing to optimise this step: a zero loss keeps the
                    # progress bar and the writer below working.
                    loss = torch.tensor(0.0, requires_grad=True)
                    ret_dict['offset_loss_list'] = [loss]
            elif (args.pretrained_ckpt is not None
                  and args.fix_semantic_instance
                  and cfg.MODEL.NAME == 'PolarOffsetSpconvPytorchMeanshift'):
                # Training dynamic shifting.
                loss = sum(ret_dict['meanshift_loss'])
            elif is_tracking_model:
                loss = sum(ret_dict['tracking_loss'])
            else:
                loss = ret_dict['loss']
            loss.backward()
            optimizer.step()

            if rank == 0:
                try:
                    cur_lr = float(optimizer.lr)
                except (AttributeError, TypeError, ValueError):
                    # Standard torch optimizers expose lr only per group.
                    cur_lr = optimizer.param_groups[0]['lr']
                loss_acc += loss.item()
                pbar.set_postfix({
                    'loss': loss.item(),
                    'lr': cur_lr,
                    'mean_loss': loss_acc / float(i_iter + 1)
                })
                pbar.update(1)
                writer.add_scalar('Train/01_Loss', ret_dict['loss'].item(),
                                  global_iter)
                writer.add_scalar('Train/02_SemLoss',
                                  ret_dict['sem_loss'].item(), global_iter)
                if 'offset_loss_list' in ret_dict and sum(
                        ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Train/03_InsLoss',
                                      sum(ret_dict['offset_loss_list']).item(),
                                      global_iter)
                writer.add_scalar('Train/04_LR', cur_lr, global_iter)
                # Next free index for the numbered scalar tags below.
                writer_acc = 5
                if 'meanshift_loss' in ret_dict:
                    writer.add_scalar('Train/05_DSLoss',
                                      sum(ret_dict['meanshift_loss']).item(),
                                      global_iter)
                    writer_acc += 1
                if 'tracking_loss' in ret_dict:
                    writer.add_scalar('Train/06_TRLoss',
                                      sum(ret_dict['tracking_loss']).item(),
                                      global_iter)
                    writer_acc += 1
                more_keys = [k for k in ret_dict.keys() if 'summary' in k]
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += writer_acc
                    writer.add_scalar(
                        'Train/{}_{}'.format(str(ki).zfill(2), k), ret_dict[k],
                        global_iter)
                global_iter += 1
        if rank == 0:
            pbar.close()

        ### evaluate after each epoch
        logger.info('----EPOCH {} Evaluating----'.format(epoch))
        model.eval()
        min_points = 50
        before_merge_evaluator = init_eval(min_points=min_points)
        after_merge_evaluator = init_eval(min_points=min_points)
        tracking_loss = 0
        if rank == 0:
            vbar = tqdm(total=len(val_dataset_loader), dynamic_ncols=True)
        for i_iter, inputs in enumerate(val_dataset_loader):
            inputs['i_iter'] = i_iter
            inputs['rank'] = rank
            with torch.no_grad():
                if is_tracking_model:
                    ret_dict = model(inputs,
                                     is_test=True,
                                     merge_evaluator_list=None,
                                     merge_evaluator_window_k_list=None,
                                     require_cluster=True)
                else:
                    ret_dict = model(
                        inputs,
                        is_test=True,
                        before_merge_evaluator=before_merge_evaluator,
                        after_merge_evaluator=after_merge_evaluator,
                        require_cluster=True)
            if rank == 0:
                vbar.set_postfix({'loss': ret_dict['loss'].item()})
                vbar.update(1)
                writer.add_scalar('Val/01_Loss', ret_dict['loss'].item(),
                                  val_global_iter)
                writer.add_scalar('Val/02_SemLoss',
                                  ret_dict['sem_loss'].item(), val_global_iter)
                if 'offset_loss_list' in ret_dict and sum(
                        ret_dict['offset_loss_list']).item() > 0:
                    writer.add_scalar('Val/03_InsLoss',
                                      sum(ret_dict['offset_loss_list']).item(),
                                      val_global_iter)
                if 'tracking_loss' in ret_dict:
                    # Logged on the validation step counter like every other
                    # Val/* scalar (previously used the training counter).
                    writer.add_scalar('Val/06_TRLoss',
                                      sum(ret_dict['tracking_loss']).item(),
                                      val_global_iter)
                    tracking_loss += sum(ret_dict['tracking_loss']).item()
                more_keys = [k for k in ret_dict.keys() if 'summary' in k]
                for ki, k in enumerate(more_keys):
                    if k == 'bandwidth_weight_summary':
                        continue
                    ki += 4
                    writer.add_scalar('Val/{}_{}'.format(str(ki).zfill(2), k),
                                      ret_dict[k], val_global_iter)
                val_global_iter += 1
        # Mean tracking loss over the val batches (guard an empty loader).
        tracking_loss /= max(len(val_dataset_loader), 1)
        if dist_train:
            try:
                before_merge_evaluator = common_utils.merge_evaluator(
                    before_merge_evaluator, tmp_dir, prefix='before_')
                dist.barrier()
                after_merge_evaluator = common_utils.merge_evaluator(
                    after_merge_evaluator, tmp_dir, prefix='after_')
            except Exception as err:
                # Best effort: keep training with the local evaluators, but
                # report what went wrong on every rank.
                print("Someting went wrong when merging evaluator in rank {}".
                      format(rank))
                print(repr(err))
        if rank == 0:
            vbar.close()
            ## print results
            logger.info("Before Merge Semantic Scores")
            before_merge_results = printResults(before_merge_evaluator,
                                                logger=logger,
                                                sem_only=True)
            logger.info("After Merge Panoptic Scores")
            after_merge_results = printResults(after_merge_evaluator,
                                               logger=logger)

            ## update the best metrics first, then save at most one ckpt
            # The tracking criterion only applies to tracking models (the
            # previous `a and b or c` expression ignored the loss comparison
            # for PolarOffsetSpconvTracking* and saved every epoch).
            tracking_improved = (is_tracking_model
                                 and tracking_loss < best_tracking_loss)
            if tracking_improved:
                best_tracking_loss = tracking_loss
            before_iou_improved = best_before_iou < before_merge_results['iou_mean']
            if before_iou_improved:
                best_before_iou = before_merge_results['iou_mean']
            pq_improved = best_pq < after_merge_results['pq_mean']
            if pq_improved:
                best_pq = after_merge_results['pq_mean']
            after_iou_improved = best_after_iou < after_merge_results['iou_mean']
            if after_iou_improved:
                best_after_iou = after_merge_results['iou_mean']

            if (tracking_improved or before_iou_improved or pq_improved
                    or after_iou_improved):
                # Built after the updates so the checkpoint records the
                # current best values rather than last epoch's.
                other_state = {
                    'best_before_iou': best_before_iou,
                    'best_pq': best_pq,
                    'best_after_iou': best_after_iou,
                    'global_iter': global_iter,
                    'val_global_iter': val_global_iter,
                    'best_tracking_loss': best_tracking_loss,
                }
                if tracking_improved:
                    ckpt_name = 'checkpoint_epoch_{}_{}.pth'.format(
                        epoch, str(tracking_loss)[:5])
                else:
                    ckpt_name = 'checkpoint_epoch_{}_{}_{}_{}.pth'.format(
                        epoch,
                        str(best_before_iou)[:5],
                        str(best_pq)[:5],
                        str(best_after_iou)[:5])
                states = train_utils.checkpoint_state(model, optimizer, epoch,
                                                      other_state)
                train_utils.save_checkpoint(states,
                                            os.path.join(ckpt_dir, ckpt_name))
            logger.info("Current best before IoU: {}".format(best_before_iou))
            logger.info("Current best after IoU: {}".format(best_after_iou))
            logger.info("Current best after PQ: {}".format(best_pq))
            logger.info(
                "Current best tracking loss: {}".format(best_tracking_loss))
        if lr_scheduler is not None:
            lr_scheduler.step(epoch)
Esempio n. 7
0
def train(is_dist, start_epoch, local_rank):
    """Train the ``retinanet`` detector on the COCO training split.

    Runs epochs ``start_epoch`` .. ``cfg.max_epochs`` (inclusive). After every
    epoch a checkpoint is saved via ``utils.save_model``. Once training ends,
    the collected progress lines are written to ``logs.txt``.

    Args:
        is_dist: If true, batches are drawn through ``distributedGroupSampler``
            and the model is wrapped in ``DistributedDataParallel`` on device
            ``local_rank``. Otherwise ``groupSampler`` is used.
        start_epoch: 1-based epoch to begin with. ``1`` loads the pretrained
            backbone weights ``pretrained_path[cfg.resnet_depth]``. Any other
            value restores the model via
            ``utils.load_model(model, start_epoch - 1)``.
        local_rank: Rank of this process. Only rank 0 prints progress, keeps
            log lines, saves checkpoints and writes ``logs.txt``.
    """
    train_transforms = transform.build_transforms()
    train_set = dataset.COCODataset(is_train=True, transforms=train_transforms)
    sampler_cls = distributedGroupSampler if is_dist else groupSampler
    loader = build_dataloader(train_set, sampler_cls(train_set))

    # Running meters, in the order they are reported and reset.
    meters = {
        "time": utils.AverageMeter(),
        "cls": utils.AverageMeter(),
        "reg": utils.AverageMeter(),
        "total": utils.AverageMeter(),
    }

    net = retinanet(is_train=True)
    if start_epoch != 1:
        # Resume from the checkpoint of the previous epoch.
        utils.load_model(net, start_epoch - 1)
    else:
        net.resnet.load_pretrained(pretrained_path[cfg.resnet_depth])
    net = net.cuda()

    if is_dist:
        # NOTE(review): from here on `net` is the DDP wrapper, and that is what
        # utils.save_model receives. Confirm it unwraps `.module`, so a resume
        # through utils.load_model on a bare model loads cleanly.
        net = torch.nn.parallel.DistributedDataParallel(
            net,
            device_ids=[local_rank],
            output_device=local_rank,
            broadcast_buffers=False,
        )
    optimizer = solver.build_optimizer(net)
    scheduler = solver.scheduler(optimizer)

    net.train()
    is_main = local_rank == 0
    log_lines = []

    for epoch in range(start_epoch, cfg.max_epochs + 1):
        if is_dist:
            # Tell the distributed sampler which (0-based) epoch this is.
            loader.sampler.set_epoch(epoch - 1)
        scheduler.lr_decay(epoch)

        # The time meter spans from the end of the previous step to the end of
        # this one, so it includes the time spent fetching the batch.
        tic = time.time()
        for step, batch in enumerate(loader, 1):
            scheduler.linear_warmup(epoch, step - 1)

            images = batch["images"].cuda()
            gt_bboxes = [box.cuda() for box in batch["bboxes"]]
            gt_labels = [lbl.cuda() for lbl in batch["labels"]]

            loss_dict = net(images,
                            gt_bboxes=gt_bboxes,
                            gt_labels=gt_labels,
                            res_img_shape=batch["res_img_shape"],
                            pad_img_shape=batch["pad_img_shape"])
            cls_loss = loss_dict["cls_loss"]
            reg_loss = loss_dict["reg_loss"]
            total_loss = cls_loss + reg_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            meters["time"].update(time.time() - tic)
            tic = time.time()

            meters["cls"].update(cls_loss.item())
            meters["reg"].update(reg_loss.item())
            meters["total"].update(total_loss.item())

            # Report (rank 0 only) and reset the meters every 50 iterations.
            if step % 50:
                continue
            if is_main:
                t_meter = meters["time"]
                c_meter = meters["cls"]
                r_meter = meters["reg"]
                l_meter = meters["total"]
                fields = (
                    "Epoch: [%d/%d]" % (epoch, cfg.max_epochs),
                    "Iter: [%d/%d]" % (step, len(loader)),
                    "Time: %.3f (%.3f)" % (t_meter.val, t_meter.avg),
                    "cls_loss: %.4f (%.4f)" % (c_meter.val, c_meter.avg),
                    "reg_loss: %.4f (%.4f)" % (r_meter.val, r_meter.avg),
                    "Loss: %.4f (%.4f)" % (l_meter.val, l_meter.avg),
                    "lr: %.6f" % (optimizer.param_groups[0]["lr"]),
                )
                line = "\t".join(fields)
                print(line)
                log_lines.append(line)
            for meter in meters.values():
                meter.reset()

        if is_main:
            utils.save_model(net, epoch)
        if is_dist:
            utils.synchronize()

    if is_main:
        with open("logs.txt", "w") as log_file:
            log_file.writelines(line + "\n" for line in log_lines)
Esempio n. 8
0
        out = self.classifier(out)
        out = out.view(-1, batchsize, self.ncls)
        out = out.permute([1, 0, 2])
        return out


if __name__ == '__main__':
    # Manual smoke test: build the seq2seq model and the data loaders, then
    # push training batches through the model and print each batch's shapes.

    from options import Options
    # NOTE(review): cv2 is only referenced by the commented-out debug code in
    # the loop below; it can likely be dropped together with that code.
    import cv2
    from dataloader import build_dataloader
    opt = Options().parse()
    # Override the parsed batch size with a small value for a quick check.
    opt.batch_size = 5

    nw = seq2seq(3, 128, 256).cuda()
    # Presumably (train_loader, val_loader); only the first is iterated here.
    tr, vl = build_dataloader(opt)

    for index, (a, b) in enumerate(tr):

        # a: 5 x 10 x c x w x h
        # Disabled debug helper: write the first frame of the first sample as a
        # PNG and exit. NOTE(review): it uses `np`, which this block does not
        # import.
        # imgs = a[0].unbind(0)
        # imgs = list(map(lambda x: (x.permute([1, 2, 0]).numpy()*255).squeeze().astype(np.uint8), imgs))
        # for index, img in enumerate(imgs):
        #     cv2.imwrite('l_{}.png'.format(index), img)
        #     exit(1)
        # NOTE(review): `nw` was moved to CUDA, but `a` is passed exactly as the
        # loader yields it. Confirm that the loader or seq2seq.forward moves it
        # to the GPU. The output `y` is not used in the visible code.
        y = nw(a)
        print(index)
        print(a.shape)
        print(b.shape)
        print("***")