コード例 #1
0
ファイル: transforms.py プロジェクト: lilujunai/side-tuning
def _load_encoder(encoder_path):
    """Load a frozen, eval-mode encoder from a checkpoint file.

    The architecture is picked from the checkpoint path: a path containing
    ``'student'`` or ``'distil'`` gets an ``FCN5`` network, anything else a
    ``TaskonomyEncoder``. Weights are first loaded with ``strict=True``; if
    that raises ``RuntimeError`` the load is retried with ``strict=False`` and
    a summary of matched / missing / unexpected keys is printed.

    Args:
        encoder_path: Path to the checkpoint (``str`` or ``os.PathLike``).
            The loaded object is indexed with ``'state_dict'``.

    Returns:
        The network, in eval mode, with ``requires_grad = False`` on every
        parameter.

    Raises:
        RuntimeError: if the ``strict=False`` retry also fails.
    """
    # str() so pathlib.Path inputs work with the substring checks below;
    # torch.load itself still receives the original argument.
    path_str = str(encoder_path)
    if 'student' in path_str or 'distil' in path_str:
        net = FCN5(normalize_outputs=True, eval_only=True, train=False)
    else:
        net = TaskonomyEncoder()  # intentionally not moved to a GPU here
    net.eval()
    # load_state_dict copies checkpoint values into the net's own parameters,
    # so the device the tensors are deserialized onto does not affect the
    # result; mapping to CPU just avoids a crash when a GPU-saved checkpoint
    # is opened on a machine without CUDA.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    try:
        net.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        incompatible = net.load_state_dict(state_dict, strict=False)
        if incompatible is None:
            # Presumably an older torch whose load_state_dict returns None
            # instead of a missing/unexpected-keys report -- TODO confirm.
            warnings.warn(
                'load_state_dict not showing missing/unexpected keys!')
        else:
            num_matches = sum(1 for k in net.state_dict() if k in state_dict)
            print(
                f'{e}, reloaded with strict=False \n'
                f'Num matches: {num_matches}\n'
                f'Num missing: {len(incompatible.missing_keys)} \n'
                f'Num unexpected: {len(incompatible.unexpected_keys)}')
    # Freeze the encoder so no gradients are tracked for its weights.
    for p in net.parameters():
        p.requires_grad = False
    # NOTE: an earlier variant wrapped the encoder with a non-affine
    # nn.GroupNorm(32, 32) via Compose; that wrapping is currently disabled.
    return net
コード例 #2
0
def _load_encoder(encoder_path):
    """Load a frozen, eval-mode ``TaskonomyEncoder`` from a checkpoint file.

    Args:
        encoder_path: Path to the checkpoint. The loaded object is indexed
            with ``'state_dict'``.

    Returns:
        The encoder, in eval mode, with ``requires_grad = False`` on every
        parameter.

    Raises:
        RuntimeError: if the state dict does not match the encoder exactly
            (``load_state_dict`` is called with its default ``strict=True``).
    """
    net = TaskonomyEncoder()  # intentionally not moved to a GPU here
    net.eval()
    # load_state_dict copies checkpoint values into the net's own parameters,
    # so deserializing onto the CPU gives the same result while avoiding a
    # crash when a GPU-saved checkpoint is opened on a machine without CUDA.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    # Freeze the encoder so no gradients are tracked for its weights.
    for param in net.parameters():
        param.requires_grad = False
    # NOTE: an earlier variant wrapped the encoder with a non-affine
    # nn.GroupNorm(32, 32) via Compose; that wrapping is currently disabled.
    return net