コード例 #1
0
def taskonomy_features_transform_collated(task_path, dtype=np.float32):
    ''' Build a transform that encodes a batch of images with a Taskonomy encoder.

        The encoder is created on the GPU and its weights are loaded from
        ``task_path`` once, when this function is called; every transform
        produced by the returned thunk reuses that same network.

        Args:
            task_path: path to a checkpoint whose ``'state_dict'`` entry is
                loaded into a ``TaskonomyEncoder``.
            dtype: dtype of the advertised output space (must be np, not torch)

        Returns:
            a function which takes an observation space (ignored) and returns
            ``(transform, output_space)``: ``transform`` maps a batch of
            images to encoder outputs moved to the CPU, and ``output_space``
            is ``spaces.Box(-1, 1, (8, 16, 16), dtype)``.
    '''
    net = TaskonomyEncoder().cuda()
    net.eval()
    checkpoint = torch.load(task_path)
    net.load_state_dict(checkpoint['state_dict'])

    def encode(x):
        with torch.no_grad():
            # A single conversion handles numpy arrays, lists, and CPU or CUDA
            # tensors. The previous torch.Tensor(x) call rejected tensors that
            # were not CPU float32, and the isinstance check that followed it
            # was always true, leaving its fallback branch unreachable.
            x = torch.as_tensor(x, dtype=torch.float32, device='cuda')

            # Assumes a rank-4 channels-last batch (N, H, W, C) with pixel
            # values in [0, 255] -- TODO confirm with callers. The permute
            # moves the last axis to position 1, and the arithmetic maps
            # [0, 255] onto [-1, 1].
            x = x.permute(0, 3, 1, 2) / 255.0
            x = 2.0 * x - 1.0
            return net(x)

    def _taskonomy_features_transform_thunk(obs_space):
        pipeline = lambda x: encode(x).cpu()
        return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)

    return _taskonomy_features_transform_thunk
コード例 #2
0
def taskonomy_features_transform(task_path, dtype=np.float32):
    ''' Build a transform that encodes inputs with a Taskonomy encoder.

        The encoder is created on the GPU and its weights are loaded from
        ``task_path`` once, when this function is called; every transform
        produced by the returned thunk reuses that same network.

        The input is only converted to a float32 CUDA tensor before being fed
        to the network; no rescaling or axis permutation happens here, so
        callers presumably pass already-preprocessed input (TODO confirm the
        expected layout and value range).

        Args:
            task_path: path to a checkpoint whose ``'state_dict'`` entry is
                loaded into a ``TaskonomyEncoder``.
            dtype: dtype of the advertised output space (must be np, not torch)

        Returns:
            a function which takes an observation space (ignored) and returns
            ``(transform, output_space)``: ``transform`` maps an input to
            encoder outputs moved to the CPU, and ``output_space`` is
            ``spaces.Box(-1, 1, (8, 16, 16), dtype)``.
    '''
    net = TaskonomyEncoder().cuda()
    net.eval()
    checkpoint = torch.load(task_path)
    net.load_state_dict(checkpoint['state_dict'])

    def encode(x):
        with torch.no_grad():
            return net(x)

    def _taskonomy_features_transform_thunk(obs_space):
        def pipeline(x):
            # torch.as_tensor accepts numpy arrays, lists, and CPU or CUDA
            # tensors of any dtype. The legacy torch.Tensor(x) constructor
            # rejected tensor inputs that were not already CPU float32.
            x = torch.as_tensor(x, dtype=torch.float32, device='cuda')
            x = encode(x)
            return x.cpu()

        return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)

    return _taskonomy_features_transform_thunk
コード例 #3
0
ファイル: transforms.py プロジェクト: lilujunai/side-tuning
def _load_encoder(encoder_path):
    """Load a frozen, eval-mode encoder network from a checkpoint.

    An ``FCN5`` student network is built when the path contains 'student' or
    'distil'; otherwise a ``TaskonomyEncoder`` is built. Weights come from the
    checkpoint's ``'state_dict'`` entry. If the strict load raises a
    ``RuntimeError``, the weights are reloaded with ``strict=False`` and a
    summary of matched/missing/unexpected keys is printed.

    Args:
        encoder_path: path (str or os.PathLike) to a checkpoint file.

    Returns:
        The network in eval mode with ``requires_grad=False`` on every
        parameter. This function does not move it to a GPU.

    Raises:
        RuntimeError: if even the ``strict=False`` reload fails (for example,
            on a parameter shape mismatch).
    """
    # str() lets pathlib.Path inputs work with the substring checks below.
    path_str = str(encoder_path)
    if 'student' in path_str or 'distil' in path_str:
        net = FCN5(normalize_outputs=True, eval_only=True, train=False)
    else:
        net = TaskonomyEncoder()
    net.eval()

    # Nothing here moves the network to a GPU, so map the checkpoint to the
    # CPU. GPU-saved checkpoints then load on CUDA-less machines without a
    # throwaway GPU allocation. load_state_dict copies the values into the
    # parameters wherever those parameters live.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    try:
        net.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        incompatible = net.load_state_dict(state_dict, strict=False)
        if incompatible is None:
            # Older PyTorch releases returned None here instead of a report
            # of missing/unexpected keys.
            warnings.warn(
                'load_state_dict not showing missing/unexpected keys!')
        else:
            num_matches = sum(1 for k in net.state_dict() if k in state_dict)
            print(
                f'{e}, reloaded with strict=False \n'
                f'Num matches: {num_matches}\n'
                f'Num missing: {len(incompatible.missing_keys)} \n'
                f'Num unexpected: {len(incompatible.unexpected_keys)}')

    # Freeze the encoder: it is used only for inference.
    for p in net.parameters():
        p.requires_grad = False
    return net
コード例 #4
0
def _load_encoder(encoder_path):
    """Load a frozen, eval-mode ``TaskonomyEncoder`` from a checkpoint.

    Args:
        encoder_path: path to a checkpoint whose ``'state_dict'`` entry holds
            the encoder weights.

    Returns:
        The encoder in eval mode with ``requires_grad=False`` on every
        parameter. This function does not move it to a GPU.
    """
    net = TaskonomyEncoder()
    net.eval()
    # The encoder is never moved to a GPU here, so load the checkpoint onto
    # the CPU as well. GPU-saved checkpoints then load without CUDA and
    # allocate no GPU memory that would be discarded immediately.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    # Freeze the encoder: it is used only for inference.
    for p in net.parameters():
        p.requires_grad = False
    return net