Пример #1
0
def _load_encoder(encoder_path):
    ''' Build a frozen, eval-mode encoder and load its weights from a checkpoint.

        The architecture is chosen from the path text: an FCN5 when the path
        contains 'student' or 'distil', a TaskonomyEncoder otherwise. Weights
        are read from checkpoint['state_dict']. A strict load is tried first;
        if it raises RuntimeError the weights are reloaded with strict=False
        and a summary of the key mismatch is printed.

        Args:
            encoder_path: path of a checkpoint readable by torch.load that
                holds the weights under the 'state_dict' key.

        Returns:
            The network in eval mode with requires_grad=False on every
            parameter. It is not moved to the GPU here.
    '''
    # NOTE(review): another `_load_encoder` is defined further down in this
    # source; if both end up in the same module the later one replaces this one.
    if 'student' in encoder_path or 'distil' in encoder_path:
        net = FCN5(normalize_outputs=True, eval_only=True, train=False)
    else:
        net = TaskonomyEncoder()
    net.eval()

    # The network lives on the CPU at this point, so map the checkpoint onto
    # the CPU as well. Without map_location, a checkpoint saved from GPU
    # tensors needs the original CUDA device just to be deserialized.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    try:
        net.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        # Best effort: accept whatever keys do match and report the rest.
        incompatible = net.load_state_dict(state_dict, strict=False)
        if incompatible is None:
            # load_state_dict returned nothing, so there is no key report.
            warnings.warn(
                'load_state_dict not showing missing/unexpected keys!')
        else:
            num_matches = sum(1 for k in net.state_dict() if k in state_dict)
            print(
                f'{e}, reloaded with strict=False \n'
                f'Num matches: {num_matches}\n'
                f'Num missing: {len(incompatible.missing_keys)} \n'
                f'Num unexpected: {len(incompatible.unexpected_keys)}')

    # Freeze the encoder: no parameter takes part in gradient computation.
    for p in net.parameters():
        p.requires_grad = False
    return net
def _load_encoder(encoder_path):
    ''' Build a frozen, eval-mode TaskonomyEncoder from a checkpoint.

        Args:
            encoder_path: path of a checkpoint readable by torch.load that
                holds the weights under the 'state_dict' key.

        Returns:
            The TaskonomyEncoder in eval mode with requires_grad=False on
            every parameter. It is not moved to the GPU here.
    '''
    # NOTE(review): a function with this same name is also defined earlier in
    # this source; if both live in one module, this later definition wins.
    net = TaskonomyEncoder()
    net.eval()
    # The network stays on the CPU, so deserialize onto the CPU too. Without
    # map_location, a checkpoint saved from GPU tensors needs the original
    # CUDA device just to be read.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])
    # Freeze the encoder: no parameter takes part in gradient computation.
    for p in net.parameters():
        p.requires_grad = False
    return net
def taskonomy_features_transform_collated(task_path, dtype=np.float32):
    ''' Create a transform thunk that encodes image batches with one
        TaskonomyEncoder running on the GPU.

        Args:
            task_path: path of a checkpoint readable by torch.load that holds
                the encoder weights under the 'state_dict' key.
            dtype: dtype given to the returned spaces.Box
                (must be np, not torch)

        Returns:
            a function which takes the observation space (unused) and returns
            (pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)); pipeline maps a
            batch to the encoder output, moved back to the CPU.
    '''
    net = TaskonomyEncoder().cuda()
    net.eval()
    checkpoint = torch.load(task_path)
    net.load_state_dict(checkpoint['state_dict'])

    def encode(x):
        with torch.no_grad():
            # One conversion covers arrays and tensors alike. The previous
            # code converted first and then tested isinstance(x, torch.Tensor),
            # which was always true, leaving its else branch unreachable.
            x = torch.as_tensor(x, dtype=torch.float32).cuda()

            # Move axis 3 to position 1 and rescale by 1/255, then map to
            # 2x - 1. Assumes NHWC input with values in [0, 255], which would
            # give NCHW values in [-1, 1] -- TODO confirm with callers.
            x = x.permute(0, 3, 1, 2) / 255.0
            x = 2.0 * x - 1.0
            return net(x)

    def _taskonomy_features_transform_thunk(obs_space):
        def pipeline(x):
            return encode(x).cpu()

        return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)

    return _taskonomy_features_transform_thunk
def taskonomy_features_transform(task_path, dtype=np.float32):
    ''' Create a transform thunk that runs inputs through one TaskonomyEncoder
        on the GPU, without any rescaling or axis reordering.

        Args:
            task_path: path of a checkpoint readable by torch.load that holds
                the encoder weights under the 'state_dict' key.
            dtype: dtype given to the returned spaces.Box
                (must be np, not torch)

        Returns:
            a function which takes the observation space (unused) and returns
            (pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)); pipeline feeds
            its input to the encoder and returns the result on the CPU.
    '''
    encoder = TaskonomyEncoder().cuda()
    encoder.eval()
    weights = torch.load(task_path)['state_dict']
    encoder.load_state_dict(weights)

    def _taskonomy_features_transform_thunk(obs_space):
        def pipeline(x):
            # Input is handed to the encoder as-is after the tensor conversion.
            on_gpu = torch.Tensor(x).cuda()
            with torch.no_grad():
                features = encoder(on_gpu)
            return features.cpu()

        feature_space = spaces.Box(-1, 1, (8, 16, 16), dtype)
        return pipeline, feature_space

    return _taskonomy_features_transform_thunk
def taskonomy_features_transforms_collated(task_paths,
                                           encoder_type='taskonomy',
                                           dtype=np.float32):
    ''' Create a transform thunk that encodes image batches with several
        encoders at once and concatenates their outputs along dim 1.

        Args:
            task_paths: comma-separated checkpoint paths, or one of the special
                values 'pixels_as_state' / 'blind', in which case no encoder
                is loaded.
            encoder_type: only 'taskonomy' is recognised.
            dtype: dtype given to the returned spaces.Box
                (must be np, not torch)

        Returns:
            a function which takes the observation space (unused) and returns
            (pipeline, box). For 'pixels_as_state' the box shape is
            (8, 16, 16); otherwise it is (8 * num_tasks, 16, 16).

        Raises:
            AssertionError: if encoder_type is not recognised, or a checkpoint
                contains neither an nn.Module value nor a 'state_dict' key.
    '''
    # handles multiple taskonomy encoders at once
    no_encoders = task_paths in ('pixels_as_state', 'blind')
    num_tasks = 0
    nets = []
    if not no_encoders:
        task_path_list = [tp.strip() for tp in task_paths.split(',')]
        num_tasks = len(task_path_list)
        assert num_tasks > 0, 'at least need one path'
        if encoder_type != 'taskonomy':
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError(
                f'do not recongize encoder type {encoder_type}')
        nets = [
            TaskonomyEncoder(normalize_outputs=False)
            for _ in range(num_tasks)
        ]
        for i, path in enumerate(task_path_list):
            checkpoint = torch.load(path)
            # A checkpoint may carry a whole pickled module; prefer it over
            # the freshly built encoder when present.
            net_in_ckpt = [
                v for v in checkpoint.values() if isinstance(v, nn.Module)
            ]
            if net_in_ckpt:
                nets[i] = net_in_ckpt[0]
            elif 'state_dict' in checkpoint:
                nets[i].load_state_dict(checkpoint['state_dict'])
            else:
                raise AssertionError(
                    f'Cannot read task_path {path}, no nn.Module or state_dict found. Encoder_type is {encoder_type}'
                )
            nets[i] = nets[i].cuda()
            nets[i].eval()

    def encode(x):
        if no_encoders:
            return x
        with torch.no_grad():
            return torch.cat([net(x) for net in nets], dim=1)

    def _taskonomy_features_transform_thunk(obs_space):
        def pipeline(x):
            with torch.no_grad():
                # Accept arrays and tensors alike; result is float32 on GPU.
                x = torch.as_tensor(x, dtype=torch.float32).cuda()

                # Move axis 3 to position 1 and rescale by 1/255, then map to
                # 2x - 1. Assumes NHWC input with values in [0, 255] -- TODO
                # confirm with callers.
                x = x.permute(0, 3, 1, 2) / 255.0
                x = 2.0 * x - 1.0
                return encode(x)

        def pixels_as_state_pipeline(x):
            # `pixels_as_state` is not defined in this block; presumably a
            # module-level helper -- confirm it is in scope.
            return pixels_as_state(x).cpu()

        # Must test the `task_paths` argument. The old code tested the loop
        # variable `task_path`, which is unbound in this mode (NameError) and
        # otherwise holds a checkpoint path, so the branch was never taken.
        if task_paths == 'pixels_as_state':
            return pixels_as_state_pipeline, spaces.Box(
                -1, 1, (8, 16, 16), dtype)
        # NOTE(review): for 'blind', num_tasks is 0, so the box shape is
        # (0, 16, 16) while pipeline returns the rescaled input -- confirm
        # this is intended.
        return pipeline, spaces.Box(-1, 1, (8 * num_tasks, 16, 16), dtype)

    return _taskonomy_features_transform_thunk