Example #1
    def __init__(
        self,
        cfg_file_path,
        model_weight_url,
        detect_rate,
        common_cate,
        device,
        exclude_class=None,
    ):
        # TODO: add exclude class
        cfg = base_cfg.clone()
        cfg.merge_from_file(cfg_file_path)
        cfg.MODEL.WEIGHT = model_weight_url
        cfg.IA_STRUCTURE.MEMORY_RATE *= detect_rate
        if common_cate:
            cfg.MODEL.ROI_ACTION_HEAD.NUM_CLASSES = 15
            cfg.MODEL.ROI_ACTION_HEAD.NUM_PERSON_MOVEMENT_CLASSES = 6
            cfg.MODEL.ROI_ACTION_HEAD.NUM_OBJECT_MANIPULATION_CLASSES = 5
            cfg.MODEL.ROI_ACTION_HEAD.NUM_PERSON_INTERACTION_CLASSES = 4
        cfg.freeze()
        self.cfg = cfg

        self.model = build_detection_model(cfg)
        self.model.eval()
        self.model.to(device)

        save_dir = cfg.OUTPUT_DIR
        checkpointer = ActionCheckpointer(cfg, self.model, save_dir=save_dir)
        self.mem_pool = MemoryPool()
        self.object_pool = MemoryPool()
        self.timestamps = []
        _ = checkpointer.load(cfg.MODEL.WEIGHT)

        self.transforms, self.person_transforms, self.object_transforms = self.build_transform()

        self.device = device
        self.cpu_device = torch.device("cpu")
        # avoid a shared mutable default argument for exclude_class
        self.exclude_class = exclude_class if exclude_class is not None else []
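A minimal usage sketch for the constructor above, assuming it belongs to a detector class; the class name `ActionPredictor` and the paths below are placeholders, not taken from the snippet:

import torch

# Hypothetical instantiation: only the argument names come from the
# constructor above; the class name and file paths are assumptions.
predictor = ActionPredictor(
    cfg_file_path="path/to/config.yaml",
    model_weight_url="path/to/model_weights.pth",
    detect_rate=1.0,     # scales cfg.IA_STRUCTURE.MEMORY_RATE
    common_cate=True,    # switch to the reduced 15-class action head
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)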
Example #2
def main():
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    t_checkpoints = cfg.KD.TEACHER  # note: set this in the student config file
    train_type = cfg.KD.TRAIN_TYPE  # note: set this in the student config file
    train_type = get_train_type(train_type, t_checkpoints)
    logger.info('=> train type is {} '.format(train_type))

    if train_type == 'FPD':
        cfg_name = 'student_' + os.path.basename(args.cfg).split('.')[0]
    else:
        cfg_name = os.path.basename(args.cfg).split('.')[0]
    save_yaml_file(cfg_name, cfg, final_output_dir)

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.' + cfg.MODEL.NAME + '.get_pose_net')(cfg,
                                                               is_train=True)

    # fpd method, default NORMAL
    if train_type == 'FPD':
        tcfg = cfg.clone()
        tcfg.defrost()
        tcfg.merge_from_file(args.tcfg)
        tcfg.freeze()
        tcfg_name = 'teacher_' + os.path.basename(args.tcfg).split('.')[0]
        save_yaml_file(tcfg_name, tcfg, final_output_dir)
        # teacher model
        tmodel = eval('models.' + tcfg.MODEL.NAME + '.get_pose_net')(
            tcfg, is_train=False)

        load_checkpoint(t_checkpoints,
                        tmodel,
                        strict=True,
                        model_info='teacher_' + tcfg.MODEL.NAME)

        tmodel = torch.nn.DataParallel(tmodel, device_ids=cfg.GPUS).cuda()
        # define kd_pose loss function (criterion) and optimizer
        kd_pose_criterion = JointsMSELoss(
            use_target_weight=tcfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # copy model file
    this_dir = os.path.dirname(__file__)
    shutil.copy2(
        os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'),
        final_output_dir)
    # logger.info(pprint.pformat(model))

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    dump_input = torch.rand(
        (1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]))
    writer_dict['writer'].add_graph(model, (dump_input, ))

    logger.info(get_model_summary(model, dump_input))

    if cfg.TRAIN.CHECKPOINT:
        load_checkpoint(cfg.TRAIN.CHECKPOINT,
                        model,
                        strict=True,
                        model_info='student_' + cfg.MODEL.NAME)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    # you can choose or replace the pose_loss and kd_pose_loss types, e.g. MSE, KL, or OHKM losses
    # define pose loss function (criterion) and optimizer
    pose_criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT).cuda()

    # Data loading code
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    valid_dataset = eval('dataset.' + cfg.DATASET.DATASET)(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=cfg.TRAIN.SHUFFLE,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=cfg.PIN_MEMORY)

    best_perf = 0.0
    best_model = False
    last_epoch = -1
    optimizer = get_optimizer(cfg, model)
    begin_epoch = cfg.TRAIN.BEGIN_EPOCH
    checkpoint_file = os.path.join(final_output_dir, 'checkpoint.pth')

    if cfg.AUTO_RESUME and os.path.exists(checkpoint_file):
        logger.info("=> loading checkpoint '{}'".format(checkpoint_file))
        checkpoint = torch.load(checkpoint_file)
        begin_epoch = checkpoint['epoch']
        best_perf = checkpoint['perf']
        last_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])

        optimizer.load_state_dict(checkpoint['optimizer'])
        logger.info("=> loaded checkpoint '{}' (epoch {})".format(
            checkpoint_file, checkpoint['epoch']))

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                        cfg.TRAIN.LR_STEP,
                                                        cfg.TRAIN.LR_FACTOR,
                                                        last_epoch=last_epoch)

    # evaluate on the validation set before training; the teacher model
    # only exists when training with FPD
    if train_type == 'FPD':
        validate(cfg, valid_loader, valid_dataset, tmodel, pose_criterion,
                 final_output_dir, tb_log_dir, writer_dict)
    validate(cfg, valid_loader, valid_dataset, model, pose_criterion,
             final_output_dir, tb_log_dir, writer_dict)

    for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH):
        lr_scheduler.step()

        # fpd method, default NORMAL
        if train_type == 'FPD':
            # train for one epoch
            fpd_train(cfg, train_loader, model, tmodel, pose_criterion,
                      kd_pose_criterion, optimizer, epoch, final_output_dir,
                      tb_log_dir, writer_dict)
        else:
            # train for one epoch
            train(cfg, train_loader, model, pose_criterion, optimizer, epoch,
                  final_output_dir, tb_log_dir, writer_dict)

        # evaluate on validation set
        perf_indicator = validate(cfg, valid_loader, valid_dataset, model,
                                  pose_criterion, final_output_dir, tb_log_dir,
                                  writer_dict)

        if perf_indicator >= best_perf:
            best_perf = perf_indicator
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': cfg.MODEL.NAME,
                'state_dict': model.state_dict(),
                'best_state_dict': model.module.state_dict(),
                'perf': perf_indicator,
                'optimizer': optimizer.state_dict(),
            }, best_model, final_output_dir)

    final_model_state_file = os.path.join(final_output_dir, 'final_state.pth')
    logger.info(
        '=> saving final model state to {}'.format(final_model_state_file))
    torch.save(model.module.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
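For reference, a minimal sketch of inspecting a checkpoint produced by the loop above; the path is a placeholder, and the keys mirror the dict passed to save_checkpoint:

import torch

# Keys mirror the dict saved in the training loop above.
ckpt = torch.load("path/to/checkpoint.pth", map_location="cpu")
print(ckpt["epoch"], ckpt["model"], ckpt["perf"])
# 'state_dict' holds the DataParallel-wrapped weights ('module.' prefix);
# 'best_state_dict' holds the unwrapped model.module weights.
weights = ckpt["best_state_dict"]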
Example #3
import pprint

from config import cfg

if __name__ == "__main__":
    cfg.merge_from_file("config.yaml")
    cfg.freeze()

    cfg2 = cfg.clone()
    cfg2.defrost()
    cfg2.TRAIN.SCALES = (8, 16, 32)
    cfg2.freeze()

    print("cfg:")
    pprint.pprint(cfg)
    print("cfg2:")
    pprint.pprint(cfg2)
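The example imports cfg from a local config module. A minimal sketch of what that module might look like, assuming it is built on yacs; the option names are illustrative:

# config.py (sketch)
from yacs.config import CfgNode as CN

cfg = CN()
cfg.TRAIN = CN()
cfg.TRAIN.SCALES = (4, 8, 16)  # default; config.yaml and the clone override it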
Example #4
        os.makedirs(output_base_path, exist_ok=True)

        print("Moving dataset to target directory...")
        train_path = utils.abs_or_offset_from(cfg.train_data.loader.path,
                                              cfg_dir)
        val_path = utils.abs_or_offset_from(cfg.valid_data.loader.path,
                                            cfg_dir)
        test_path = utils.abs_or_offset_from(cfg.test_data.loader.path,
                                             cfg_dir)
        shutil.copy(train_path, output_base_path)
        if train_path != val_path:
            shutil.copy(val_path, output_base_path)
        if test_path != val_path:
            shutil.copy(test_path, output_base_path)

        cfg_new = cfg.clone()
        cfg_new.train_data.loader.path = os.path.basename(
            cfg.train_data.loader.path)
        cfg_new.valid_data.loader.path = os.path.basename(
            cfg.valid_data.loader.path)
        cfg_new.test_data.loader.path = os.path.basename(
            cfg.test_data.loader.path)
        cfg_new.output.base_path = "."
        cfg_new.skip_copy = True

        print("Writing new configuration file...")
        with open(os.path.join(output_base_path, "config.yaml"), "w") as f:
            f.write(cfg_new.dump())

        print("Writing normallization file...")
        norm_node = CfgNode()
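The snippet calls utils.abs_or_offset_from to resolve dataset paths and builds a fresh yacs CfgNode for the normalization file. A plausible implementation of that helper, inferred from its usage above and not taken from the project's actual code:

import os

def abs_or_offset_from(path, base_dir):
    # Hypothetical helper: keep absolute paths as-is, otherwise resolve
    # them relative to the configuration file's directory.
    if os.path.isabs(path):
        return path
    return os.path.normpath(os.path.join(base_dir, path))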