def taskonomy_features_transform_collated(task_path, dtype=np.float32):
    ''' Build a transform factory that encodes raw image batches with a
        Taskonomy encoder.

        A `TaskonomyEncoder` is created on the current CUDA device, put in
        eval mode, and initialised from the 'state_dict' entry of the
        checkpoint stored at `task_path`.

        Args:
            task_path: Path of the checkpoint passed to `torch.load`.
            dtype: dtype given to the returned `spaces.Box`
                (must be np, not torch).

        Returns:
            A function which takes `obs_space` (unused) and returns
            `(pipeline, space)`, where `pipeline(x)` yields the encoder
            output moved to the CPU and `space` is
            `spaces.Box(-1, 1, (8, 16, 16), dtype)`.
    '''
    net = TaskonomyEncoder().cuda()
    net.eval()
    checkpoint = torch.load(task_path)
    net.load_state_dict(checkpoint['state_dict'])

    def encode(x):
        # Presumably x is a batch of channel-last images with values in
        # [0, 255] -- TODO confirm against callers.
        with torch.no_grad():
            # A single conversion handles arrays as well as tensors of any
            # dtype/device. The legacy torch.Tensor(x) constructor rejects
            # CUDA and non-float32 tensors, and it made the former
            # isinstance() check always true (its else-branch was dead).
            x = torch.as_tensor(x, dtype=torch.float32).cuda()
            # Move axis 3 to position 1 and rescale by 1/255.
            x = x.permute(0, 3, 1, 2) / 255.0
            # Affine map: 0 -> -1, 1 -> 1.
            x = 2.0 * x - 1.0
            return net(x)

    def _taskonomy_features_transform_thunk(obs_space):
        pipeline = lambda x: encode(x).cpu()
        return pipeline, spaces.Box(-1, 1, (8, 16, 16), dtype)

    return _taskonomy_features_transform_thunk
def taskonomy_features_transform(task_path, dtype=np.float32):
    ''' Build a transform factory that runs observations through a
        Taskonomy encoder.

        The encoder is created on the current CUDA device, switched to eval
        mode, and loaded from the 'state_dict' entry of the checkpoint at
        `task_path`.

        Args:
            task_path: Path of the checkpoint passed to `torch.load`.
            dtype: dtype given to the returned `spaces.Box`
                (must be np, not torch).

        Returns:
            A function which takes `obs_space` (unused) and returns
            `(pipeline, space)`: `pipeline(x)` converts `x` with
            `torch.Tensor`, moves it to the GPU, applies the encoder without
            gradient tracking and returns the result on the CPU; `space` is
            `spaces.Box(-1, 1, (8, 16, 16), dtype)`.
    '''
    encoder = TaskonomyEncoder().cuda()
    encoder.eval()
    weights = torch.load(task_path)
    encoder.load_state_dict(weights['state_dict'])

    def _taskonomy_features_transform_thunk(obs_space):
        def pipeline(x):
            # Unlike the collated variant, no permute or rescaling is applied
            # here; the input is fed to the encoder as given.
            batch = torch.Tensor(x).cuda()
            with torch.no_grad():
                features = encoder(batch)
            return features.cpu()

        feature_space = spaces.Box(-1, 1, (8, 16, 16), dtype)
        return pipeline, feature_space

    return _taskonomy_features_transform_thunk
def _load_encoder(encoder_path):
    ''' Load a frozen encoder from a checkpoint file.

        If `encoder_path` contains 'student' or 'distil' an `FCN5` is built,
        otherwise a `TaskonomyEncoder`. The network is put in eval mode, its
        weights are read from the checkpoint's 'state_dict' entry, and every
        parameter has `requires_grad` set to False.

        A strict load is attempted first. If it raises `RuntimeError`, the
        weights are reloaded with `strict=False` and a summary of matching,
        missing and unexpected keys is printed.

        NOTE(review): a later definition of `_load_encoder` in this file
        rebinds the name, so this version is shadowed unless that one is
        removed -- confirm which is intended.

        Args:
            encoder_path: Path (str) of the checkpoint passed to `torch.load`.

        Returns:
            The loaded network, on the CPU.
    '''
    if 'student' in encoder_path or 'distil' in encoder_path:
        net = FCN5(normalize_outputs=True, eval_only=True, train=False)
    else:
        net = TaskonomyEncoder()
    net.eval()

    # The net is built on the CPU, so map the checkpoint there too: a
    # checkpoint saved from GPU tensors then loads on CUDA-less hosts and
    # does not allocate GPU memory just to be copied into CPU parameters.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    state_dict = checkpoint['state_dict']
    try:
        net.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        incompatible = net.load_state_dict(state_dict, strict=False)
        if incompatible is None:
            # Some torch versions return nothing from load_state_dict.
            warnings.warn(
                'load_state_dict not showing missing/unexpected keys!')
        else:
            num_matches = sum(k in state_dict for k in net.state_dict())
            print(
                f'{e}, reloaded with strict=False \n'
                f'Num matches: {num_matches}\n'
                f'Num missing: {len(incompatible.missing_keys)} \n'
                f'Num unexpected: {len(incompatible.unexpected_keys)}')

    # Freeze the encoder.
    for p in net.parameters():
        p.requires_grad = False
    return net
def _load_encoder(encoder_path):
    ''' Load a frozen `TaskonomyEncoder` from a checkpoint file.

        The network is put in eval mode, its weights are read (strictly) from
        the checkpoint's 'state_dict' entry, and every parameter has
        `requires_grad` set to False.

        NOTE(review): an earlier definition of `_load_encoder` appears in
        this file; being the later one, this definition is what the name
        refers to after import, and it has no 'student'/'distil' handling or
        non-strict fallback -- confirm which is intended.

        Args:
            encoder_path: Path of the checkpoint passed to `torch.load`.

        Returns:
            The loaded network, on the CPU.

        Raises:
            RuntimeError: from `load_state_dict` if the checkpoint keys do
                not match the network.
    '''
    net = TaskonomyEncoder()
    net.eval()

    # The net is built on the CPU, so map the checkpoint there too: a
    # checkpoint saved from GPU tensors then loads on CUDA-less hosts and
    # does not allocate GPU memory just to be copied into CPU parameters.
    checkpoint = torch.load(encoder_path, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    # Freeze the encoder.
    for p in net.parameters():
        p.requires_grad = False
    return net