예제 #1
0
def set_model():
    """Construct the training network together with its loss functions.

    The model class is picked from ``model_factory`` by ``cfg.model_type``
    and built for ``cfg.n_classes`` outputs. If ``args.finetune_from`` is
    set, the weights stored at that path are loaded (onto CPU first). When
    ``cfg.use_sync_bn`` is true, BatchNorm layers are converted with
    ``nn.SyncBatchNorm.convert_sync_batchnorm``. The network is then moved
    to the GPU and put in training mode.

    Returns:
        tuple: ``(net, criteria_pre, criteria_aux)`` -- the network, the
        ``OhemCELoss`` for the main output, and a list with one
        ``OhemCELoss`` per auxiliary head (``cfg.num_aux_heads`` entries).
    """
    build = model_factory[cfg.model_type]
    net = build(cfg.n_classes)

    ckpt_path = args.finetune_from
    if ckpt_path is not None:
        state = torch.load(ckpt_path, map_location='cpu')
        net.load_state_dict(state)

    if cfg.use_sync_bn:
        net = nn.SyncBatchNorm.convert_sync_batchnorm(net)

    net.cuda()
    net.train()

    # 0.7 is passed straight to OhemCELoss; presumably the OHEM score
    # threshold -- confirm against the OhemCELoss definition.
    criteria_pre = OhemCELoss(0.7)
    criteria_aux = []
    for _ in range(cfg.num_aux_heads):
        criteria_aux.append(OhemCELoss(0.7))
    return net, criteria_pre, criteria_aux
예제 #2
0
def set_model():
    """Construct the training network together with its loss functions.

    The number of output classes is read from ``cfg.n_classes`` when the
    config defines it. Otherwise it falls back to 19, the value this
    function previously hard-coded.

    If ``args.finetune_from`` is set, the file is loaded onto CPU. Both a
    wrapped checkpoint (a dict holding the weights under ``'state_dict'``)
    and a bare state dict are accepted. The network is then moved to
    ``device`` and put in training mode.

    Returns:
        tuple: ``(net, criteria_pre, criteria_aux)`` -- the network, the
        ``OhemCELoss`` for the main output, and a list with one
        ``OhemCELoss`` per auxiliary head (``cfg.num_aux_heads`` entries).
    """
    # Sibling variants of this function read the class count from the
    # config; keep 19 as the default so existing configs behave the same.
    n_classes = getattr(cfg, 'n_classes', 19)
    net = model_factory[cfg.model_type](n_classes)

    if args.finetune_from is not None:
        checkpoint = torch.load(args.finetune_from, map_location='cpu')
        # Training checkpoints wrap the weights under 'state_dict'; a file
        # saved directly from net.state_dict() does not. Indexing
        # unconditionally raised KeyError for the latter.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
        net.load_state_dict(state_dict)

    net.to(device)
    net.train()

    # 0.7 is passed straight to OhemCELoss; presumably the OHEM score
    # threshold -- confirm against the OhemCELoss definition.
    criteria_pre = OhemCELoss(0.7)
    criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)]
    return net, criteria_pre, criteria_aux
예제 #3
0
파일: train.py 프로젝트: KoskHrd/BiSeNet
def set_model():
    """Build the network and OHEM losses that skip the ignore label.

    The model class is picked from ``model_factory`` by ``cfg.model_type``
    and built with ``n_classes=cfg.n_classes``. If ``args.finetune_from`` is
    set, the weights at that path are loaded (onto CPU first). When
    ``cfg.use_sync_bn`` is true, the network is passed through
    ``set_syncbn``. It is then moved to ``device`` and set to train mode.

    Returns:
        tuple: ``(net, criteria_pre, criteria_aux)`` -- the network, the
        ``OhemCELoss`` for the main output, and a list with one
        ``OhemCELoss`` per auxiliary head (``cfg.num_aux_heads`` entries).
        Every criterion receives ``cfg.anns_ignore`` as its second argument.
    """
    net = model_factory[cfg.model_type](n_classes=cfg.n_classes)

    pretrained = args.finetune_from
    if pretrained is not None:
        net.load_state_dict(torch.load(pretrained, map_location='cpu'))

    if cfg.use_sync_bn:
        net = set_syncbn(net)

    net.to(device)
    net.train()

    # NOTE: nn.CrossEntropyLoss(ignore_index=cfg.anns_ignore) can be
    # substituted for each criterion below if plain cross-entropy is wanted
    # instead of online hard example mining (OHEM).
    # TODO: document how the OhemCELoss threshold (0.7) should be tuned.
    ignore_lb = cfg.anns_ignore
    criteria_pre = OhemCELoss(0.7, ignore_lb)
    criteria_aux = [OhemCELoss(0.7, ignore_lb) for _ in range(cfg.num_aux_heads)]
    return net, criteria_pre, criteria_aux
예제 #4
0
def set_model():
    """Construct the training network together with its loss functions.

    The model class is picked from ``model_factory`` by ``cfg.model_type``
    and built for ``cfg.n_classes`` outputs.

    For ``cfg.model_type == 'hardnet'``:

    * ``weights_init`` is applied to every submodule.
    * Pretrained weights for ``net.base`` are loaded from
      ``./hardnet_weights/hardnet_petite_base.pth``.

    If ``args.finetune_from`` is set, those weights are then loaded for the
    whole network. When ``cfg.use_sync_bn`` is true, the network is passed
    through ``set_syncbn``. It is then moved to the GPU and put in training
    mode.

    Returns:
        tuple: ``(net, criteria_pre, criteria_aux)`` -- the network, the
        ``OhemCELoss`` for the main output, and a list with one
        ``OhemCELoss`` per auxiliary head (``cfg.num_aux_heads`` entries).
    """
    net = model_factory[cfg.model_type](cfg.n_classes)

    if cfg.model_type == 'hardnet':
        net.apply(weights_init)
        pretrained_path = './hardnet_weights/hardnet_petite_base.pth'
        # Load onto CPU, like the finetune checkpoint below. Without
        # map_location, torch.load restores tensors to the device they were
        # saved from, which fails on hosts lacking that GPU and puts every
        # process's copy on that same GPU. net.cuda() moves the weights
        # afterwards anyway.
        weights = torch.load(pretrained_path, map_location='cpu')
        net.base.load_state_dict(weights)

    if args.finetune_from is not None:
        net.load_state_dict(torch.load(args.finetune_from, map_location='cpu'))
    if cfg.use_sync_bn:
        net = set_syncbn(net)
    net.cuda()
    net.train()

    # 0.7 is passed straight to OhemCELoss; presumably the OHEM score
    # threshold -- confirm against the OhemCELoss definition.
    criteria_pre = OhemCELoss(0.7)
    criteria_aux = [OhemCELoss(0.7) for _ in range(cfg.num_aux_heads)]
    return net, criteria_pre, criteria_aux