def __init__(self, nusc, is_train, cfg):
    self.nusc = nusc
    self.is_train = is_train
    self.cfg = cfg

    self.is_lyft = isinstance(nusc, LyftDataset)

    if self.is_lyft:
        self.dataroot = self.nusc.data_path
    else:
        self.dataroot = self.nusc.dataroot

    self.mode = 'train' if self.is_train else 'val'

    self.sequence_length = cfg.TIME_RECEPTIVE_FIELD + cfg.N_FUTURE_FRAMES

    self.scenes = self.get_scenes()
    self.ixes = self.prepro()
    self.indices = self.get_indices()

    # Image resizing and cropping
    self.augmentation_parameters = self.get_resizing_and_cropping_parameters()

    # Normalising input images
    self.normalise_image = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225]),
    ])

    # Bird's-eye view parameters
    bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters(
        cfg.LIFT.X_BOUND, cfg.LIFT.Y_BOUND, cfg.LIFT.Z_BOUND
    )
    self.bev_resolution, self.bev_start_position, self.bev_dimension = (
        bev_resolution.numpy(), bev_start_position.numpy(), bev_dimension.numpy()
    )

    # Spatial extent in bird's-eye view, in meters
    self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
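# A minimal sketch of what calculate_birds_eye_view_parameters might compute, assuming
# each bound is given as (min, max, resolution) in metres. This is an illustrative
# assumption, not the library's actual implementation: per axis it returns the cell size,
# the centre of the first cell, and the number of cells, which is consistent with how the
# returned tensors are used in the two constructors in this file.
import torch

def calculate_birds_eye_view_parameters_sketch(x_bound, y_bound, z_bound):
    bounds = [x_bound, y_bound, z_bound]  # each assumed to be (min, max, resolution)
    # Cell size per axis, in metres.
    bev_resolution = torch.tensor([res for _, _, res in bounds])
    # Centre of the first cell per axis.
    bev_start_position = torch.tensor([lo + res / 2.0 for lo, _, res in bounds])
    # Number of cells per axis.
    bev_dimension = torch.tensor(
        [int((hi - lo) / res) for lo, hi, res in bounds], dtype=torch.long
    )
    return bev_resolution, bev_start_position, bev_dimension

# Example (hypothetical bounds): calculate_birds_eye_view_parameters_sketch(
#     (-50.0, 50.0, 0.5), (-50.0, 50.0, 0.5), (-10.0, 10.0, 20.0))
# -> resolution (0.5, 0.5, 20.0), start (-49.75, -49.75, 0.0), dimension (200, 200, 1)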
def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg

    bev_resolution, bev_start_position, bev_dimension = calculate_birds_eye_view_parameters(
        self.cfg.LIFT.X_BOUND, self.cfg.LIFT.Y_BOUND, self.cfg.LIFT.Z_BOUND
    )
    self.bev_resolution = nn.Parameter(bev_resolution, requires_grad=False)
    self.bev_start_position = nn.Parameter(bev_start_position, requires_grad=False)
    self.bev_dimension = nn.Parameter(bev_dimension, requires_grad=False)

    self.encoder_downsample = self.cfg.MODEL.ENCODER.DOWNSAMPLE
    self.encoder_out_channels = self.cfg.MODEL.ENCODER.OUT_CHANNELS

    self.frustum = self.create_frustum()
    self.depth_channels, _, _, _ = self.frustum.shape

    if self.cfg.TIME_RECEPTIVE_FIELD == 1:
        assert self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity'

    # temporal block
    self.receptive_field = self.cfg.TIME_RECEPTIVE_FIELD
    self.n_future = self.cfg.N_FUTURE_FRAMES
    self.latent_dim = self.cfg.MODEL.DISTRIBUTION.LATENT_DIM

    if self.cfg.MODEL.SUBSAMPLE:
        assert self.cfg.DATASET.NAME == 'lyft'
        self.receptive_field = 3
        self.n_future = 5

    # Spatial extent in bird's-eye view, in meters
    self.spatial_extent = (self.cfg.LIFT.X_BOUND[1], self.cfg.LIFT.Y_BOUND[1])
    self.bev_size = (self.bev_dimension[0].item(), self.bev_dimension[1].item())

    # Encoder
    self.encoder = Encoder(cfg=self.cfg.MODEL.ENCODER, D=self.depth_channels)

    # Temporal model
    temporal_in_channels = self.encoder_out_channels
    if self.cfg.MODEL.TEMPORAL_MODEL.INPUT_EGOPOSE:
        temporal_in_channels += 6
    if self.cfg.MODEL.TEMPORAL_MODEL.NAME == 'identity':
        self.temporal_model = TemporalModelIdentity(temporal_in_channels, self.receptive_field)
    elif cfg.MODEL.TEMPORAL_MODEL.NAME == 'temporal_block':
        self.temporal_model = TemporalModel(
            temporal_in_channels,
            self.receptive_field,
            input_shape=self.bev_size,
            start_out_channels=self.cfg.MODEL.TEMPORAL_MODEL.START_OUT_CHANNELS,
            extra_in_channels=self.cfg.MODEL.TEMPORAL_MODEL.EXTRA_IN_CHANNELS,
            n_spatial_layers_between_temporal_layers=self.cfg.MODEL.TEMPORAL_MODEL.INBETWEEN_LAYERS,
            use_pyramid_pooling=self.cfg.MODEL.TEMPORAL_MODEL.PYRAMID_POOLING,
        )
    else:
        raise NotImplementedError(f'Temporal module {self.cfg.MODEL.TEMPORAL_MODEL.NAME}.')

    self.future_pred_in_channels = self.temporal_model.out_channels

    if self.n_future > 0:
        # probabilistic sampling
        if self.cfg.PROBABILISTIC.ENABLED:
            # Distribution networks
            self.present_distribution = DistributionModule(
                self.future_pred_in_channels,
                self.latent_dim,
                min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
                max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
            )

            future_distribution_in_channels = (
                self.future_pred_in_channels + self.n_future * self.cfg.PROBABILISTIC.FUTURE_DIM
            )
            self.future_distribution = DistributionModule(
                future_distribution_in_channels,
                self.latent_dim,
                min_log_sigma=self.cfg.MODEL.DISTRIBUTION.MIN_LOG_SIGMA,
                max_log_sigma=self.cfg.MODEL.DISTRIBUTION.MAX_LOG_SIGMA,
            )

        # Future prediction
        self.future_prediction = FuturePrediction(
            in_channels=self.future_pred_in_channels,
            latent_dim=self.latent_dim,
            n_gru_blocks=self.cfg.MODEL.FUTURE_PRED.N_GRU_BLOCKS,
            n_res_layers=self.cfg.MODEL.FUTURE_PRED.N_RES_LAYERS,
        )

    # Decoder
    self.decoder = Decoder(
        in_channels=self.future_pred_in_channels,
        n_classes=len(self.cfg.SEMANTIC_SEG.WEIGHTS),
        predict_future_flow=self.cfg.INSTANCE_FLOW.ENABLED,
    )

    set_bn_momentum(self, self.cfg.MODEL.BN_MOMENTUM)
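# A minimal sketch of what create_frustum() might return, based only on how it is used
# above (self.depth_channels, _, _, _ = self.frustum.shape). The argument names final_dim
# and d_bound below are illustrative stand-ins, not the actual config keys. The idea,
# in the lift-splat style, is a (D, H', W', 3) grid of (x_pixel, y_pixel, depth)
# candidates at the encoder's output resolution, so depth_channels = frustum.shape[0].
import torch

def create_frustum_sketch(final_dim, d_bound, downsample):
    h, w = final_dim                                   # input image size in pixels
    fh, fw = h // downsample, w // downsample          # feature-map size after the encoder
    # One plane of (x, y) pixel locations per candidate depth bin.
    depths = torch.arange(*d_bound, dtype=torch.float).view(-1, 1, 1).expand(-1, fh, fw)
    d, _, _ = depths.shape
    xs = torch.linspace(0, w - 1, fw, dtype=torch.float).view(1, 1, fw).expand(d, fh, fw)
    ys = torch.linspace(0, h - 1, fh, dtype=torch.float).view(1, fh, 1).expand(d, fh, fw)
    # Shape (D, H', W', 3): image coordinates plus depth, to be lifted into 3D later.
    return torch.stack((xs, ys, depths), dim=-1)

# Example with hypothetical values:
# create_frustum_sketch((224, 480), (2.0, 50.0, 1.0), 8).shape -> torch.Size([48, 28, 60, 3])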