Example #1
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    args = parser.parse_args()

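    # WORLD_SIZE is set by the torch.distributed launcher (one process per
    # GPU); fall back to single-GPU mode when it is absent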
    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    if is_main_process() and not os.path.exists(cfg.PRESENT_DIR):
        os.mkdir(cfg.PRESENT_DIR)
    logger = get_logger(
        cfg.DATASET.NAME, cfg.PRESENT_DIR, args.local_rank, 'present_log.txt')

    model = MSPN(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

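    # load pretrained weights if the checkpoint file exists; otherwise the
    # model silently keeps its random initialization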
    model_file = "/home/zqr/codes/MSPN/lib/models/mspn_2xstg_coco.pth"
    if os.path.exists(model_file):
        state_dict = torch.load(
            model_file, map_location=lambda storage, loc: storage)
        state_dict = state_dict['model']
        model.load_state_dict(state_dict)

    data_loader = get_present_loader(cfg, num_gpus, args.local_rank, cfg.INFO_PATH,
                                     is_dist=distributed)

    results = inference(model, data_loader, logger, device)
    synchronize()

    if is_main_process():
        logger.info("Dumping results ...")
        results.sort(
            key=lambda res: (res['image_id'], res['score']), reverse=True)
        results_path = os.path.join(cfg.PRESENT_DIR, 'results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f)
        logger.info("Get all results.")
        for res in results:
            data_numpy = cv2.imread(os.path.join(
                cfg.IMG_FOLDER, res['image_id']), cv2.IMREAD_COLOR)
            img = data_loader.ori_dataset.visualize(
                data_numpy, res['keypoints'], res['score'])
            cv2.imwrite(os.path.join(cfg.PRESENT_DIR, res['image_id']), img)
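For context, the synchronize() calls above act as a cross-process barrier so that every rank finishes inference before rank 0 dumps the results. A minimal sketch of such a helper, assuming the common maskrcnn-benchmark-style implementation rather than this repo's exact code:

# Minimal sketch of a synchronize() helper; an assumption modeled on the
# maskrcnn-benchmark convention, not necessarily this repo's implementation.
import torch.distributed as dist

def synchronize():
    """Barrier across all processes; a no-op outside distributed runs."""
    if not dist.is_available() or not dist.is_initialized():
        return
    if dist.get_world_size() == 1:
        return
    dist.barrier()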
Example #2
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    parser.add_argument("--iter", "-i", type=int, default=-1)
    args = parser.parse_args()

    num_gpus = int(
        os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    distributed = num_gpus > 1

    if distributed:
        torch.cuda.set_device(args.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        synchronize()

    if is_main_process() and not os.path.exists(cfg.TEST_DIR):
        os.mkdir(cfg.TEST_DIR)
    logger = get_logger(cfg.DATASET.NAME, cfg.TEST_DIR, args.local_rank,
                        'test_log.txt')

    if args.iter == -1:
        logger.info("Please designate one iteration.")
        return

    model = MSPN(cfg)
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    model_file = os.path.join(cfg.OUTPUT_DIR, "iter-{}.pth".format(args.iter))
    if os.path.exists(model_file):
        state_dict = torch.load(model_file,
                                map_location=lambda storage, loc: storage)
        state_dict = state_dict['model']
        model.load_state_dict(state_dict)

    data_loader = get_test_loader(cfg,
                                  num_gpus,
                                  args.local_rank,
                                  'val',
                                  is_dist=distributed)

    results = inference(model, data_loader, logger, device)
    synchronize()

    if is_main_process():
        logger.info("Dumping results ...")
        results.sort(key=lambda res: (res['image_id'], res['score']),
                     reverse=True)
        results_path = os.path.join(cfg.TEST_DIR, 'results.json')
        with open(results_path, 'w') as f:
            json.dump(results, f)
        logger.info("Get all results.")

        data_loader.ori_dataset.evaluate(results_path)
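Example #2 ends by handing results.json to the dataset's evaluate() method. For COCO-style keypoints this step typically wraps pycocotools' COCOeval; the sketch below is a hypothetical version of it (the annotation path and standalone function are assumptions, not taken from this repo):

# Hypothetical sketch of a COCO keypoint evaluation step using pycocotools.
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

def evaluate(results_path,
             ann_file="annotations/person_keypoints_val2017.json"):
    coco_gt = COCO(ann_file)                 # ground-truth annotations
    coco_dt = coco_gt.loadRes(results_path)  # detections from results.json
    coco_eval = COCOeval(coco_gt, coco_dt, iouType="keypoints")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()                    # prints the AP/AR table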
Example #3
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--image",
        default="https://miro.medium.com/max/1200/1*56MtNM2fh_mdG3iGnD7_ZQ.jpeg",
        help="URL of an image to run pose estimation on")
    parser.add_argument(
        '--model_path',
        default="drive-download-20210107T043037Z-001/mspn_2xstg_coco.pth")
    args = parser.parse_args()

    print("-" * 70)
    print(":: Loading the model")
    cfg = get_config("coco")
    model = MSPN(cfg)
    state_dict = torch.load(args.model_path, map_location="cpu")
    state_dict = state_dict['model']
    model.load_state_dict(state_dict)

    # define the image transformations to apply to each image
    image_transformations = transforms.Compose([
        transforms.Resize((256, 192)),  # resize to a (256,192) image
        transforms.ToTensor(),  # convert to tensor
        transforms.Normalize(
            cfg.INPUT.MEANS,
            cfg.INPUT.STDS),  # normalize with the dataset mean/std from the config
    ])
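Example #3 defines the model and transforms but stops before running inference. A hedged continuation sketch, assuming the usual requests/PIL loading path; MSPN's eval-mode output format is repo-specific, so the forward pass below is illustrative only:

# Hedged continuation (not in the original snippet): fetch the image,
# preprocess it with the transforms above, and run a forward pass.
import requests
from io import BytesIO
from PIL import Image
import torch

resp = requests.get(args.image)
image = Image.open(BytesIO(resp.content)).convert("RGB")
batch = image_transformations(image).unsqueeze(0)  # shape (1, 3, 256, 192)

model.eval()
with torch.no_grad():
    outputs = model(batch)  # keypoint heatmaps; exact format is repo-specific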
Example #4
def main():
    parser = argparse.ArgumentParser()

    with Engine(cfg, custom_parser=parser) as engine:
        logger = engine.setup_log(
            name='train', log_dir=cfg.OUTPUT_DIR, file_name='log.txt')
        args = parser.parse_args()
        ensure_dir(cfg.OUTPUT_DIR)

        model = MSPN(cfg, run_efficient=cfg.RUN_EFFICIENT)
        device = torch.device(cfg.MODEL.DEVICE)
        model.to(device)

        num_gpu = len(engine.devices)
        # the iteration-based settings in the config assume 8 GPUs; rescale
        # them linearly so the total number of training samples is unchanged
        cfg.SOLVER.CHECKPOINT_PERIOD = \
            int(cfg.SOLVER.CHECKPOINT_PERIOD * 8 / num_gpu)
        cfg.SOLVER.MAX_ITER = int(cfg.SOLVER.MAX_ITER * 8 / num_gpu)
        optimizer = make_optimizer(cfg, model, num_gpu)
        scheduler = make_lr_scheduler(cfg, optimizer)

        engine.register_state(
            scheduler=scheduler, model=model, optimizer=optimizer)

        if engine.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank],
                broadcast_buffers=False)

        if engine.continue_state_object:
            engine.restore_checkpoint(is_restore=False)
        elif cfg.MODEL.WEIGHT:
            engine.load_checkpoint(cfg.MODEL.WEIGHT, is_restore=False)

        data_loader = get_train_loader(
            cfg, num_gpu=num_gpu, is_dist=engine.distributed)

        # ------------ do training ---------------------------- #
        logger.info("\n\nStart training with pytorch version {}".format(
            torch.__version__))

        max_iter = len(data_loader)
        checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
        tb_writer = SummaryWriter(cfg.TENSORBOARD_DIR)

        model.train()

        time1 = time.time()
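        # resume counting from engine.state.iteration so a restored
        # checkpoint continues where it left off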
        for iteration, (images, valids, labels) in enumerate(
                data_loader, engine.state.iteration):
            iteration = iteration + 1
            images = images.to(device)
            valids = valids.to(device)
            labels = labels.to(device)

            # the model returns a dict of named losses; sum them into one total
            loss_dict = model(images, valids, labels)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            scheduler.step()

            if cfg.RUN_EFFICIENT:
                del images, valids, labels, losses

            if engine.local_rank == 0:
                if iteration % 20 == 0 or iteration == max_iter:
                    log_str = 'Iter:%d, LR:%.1e, ' % (
                        iteration, optimizer.param_groups[0]["lr"] / num_gpu)
                    for key in loss_dict:
                        tb_writer.add_scalar(
                            key, loss_dict[key].mean(), global_step=iteration)
                        log_str += key + ': %.3f, ' % float(loss_dict[key])

                    time2 = time.time()
                    elapsed_time = time2 - time1
                    time1 = time2
                    required_time = elapsed_time / 20 * (max_iter - iteration)
                    hours = required_time // 3600
                    mins = required_time % 3600 // 60
                    log_str += 'To finish: %dh%dmin' % (hours, mins)

                    logger.info(log_str)

            if iteration % checkpoint_period == 0 or iteration == max_iter:
                engine.update_iteration(iteration)
                # only the rank-0 process writes checkpoints when distributed
                if not engine.distributed or engine.local_rank == 0:
                    engine.save_and_link_checkpoint(cfg.OUTPUT_DIR)

            if iteration >= max_iter:
                logger.info('Finish training process!')
                break
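The 8 / num_gpu rescaling near the top of Example #4 stretches the schedule when fewer than 8 GPUs are available, keeping the total amount of training data seen constant since each iteration then processes a smaller global batch. A quick illustration with hypothetical numbers:

# Hypothetical numbers, for illustration only.
base_max_iter = 96000                        # schedule tuned for 8 GPUs
num_gpu = 2
max_iter = int(base_max_iter * 8 / num_gpu)  # -> 384000 iterations on 2 GPUs
# the global batch per iteration is 4x smaller and the schedule 4x longer,
# so the total number of training samples processed stays the same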