def __post_init__(self):
    '''
    Complete the set-up of the configuration object.

    Reads the static and dynamic gesture mappings and the gesture-to-action
    mapping from their JSON files, creates the mouse controller and the user
    configuration and, unless `self.lite` is set, builds both networks, loads
    their saved weights and switches them to evaluation mode.
    '''
    # The three mappings are plain JSON documents.
    with open('gestop/data/static_gesture_mapping.json', 'r') as static_file:
        self.static_gesture_mapping = json.load(static_file)
    with open('gestop/data/dynamic_gesture_mapping.json', 'r') as dynamic_file:
        self.dynamic_gesture_mapping = json.load(dynamic_file)
    with open(self.config_path, 'r') as action_file:
        self.gesture_action_mapping = json.load(action_file)

    self.mouse = Controller()
    self.user_config = UserConfig()

    # A lite configuration never instantiates the networks.
    if self.lite:
        return

    logging.info('Loading GestureNet...')
    self.gesture_net = GestureNet(
        self.static_input_dim,
        self.static_output_classes,
        self.static_gesture_mapping,
    )
    static_state = torch.load(self.static_path, map_location=self.map_location)
    self.gesture_net.load_state_dict(static_state)
    self.gesture_net.eval()

    logging.info('Loading ShrecNet..')
    self.shrec_net = ShrecNet(
        self.dynamic_input_dim,
        self.dynamic_output_classes,
        self.dynamic_gesture_mapping,
    )
    # The weights for this network are read from `dynamic_path`.
    dynamic_state = torch.load(self.dynamic_path, map_location=self.map_location)
    self.shrec_net.load_state_dict(dynamic_state)
    self.shrec_net.eval()
class Config:
    '''
    The configuration of the application.

    The statements at the top of the class body run once, when the class is
    created: they select the device that saved weights are mapped to, make sure
    the log directory exists and configure logging to a timestamped file and
    to stdout.

    The annotated attributes below are the configurable fields; `lite` is the
    only one without a default. `__post_init__` then loads the JSON mappings
    and, unless `lite` is true, the two networks.
    '''
    # Map saved tensors to the CPU when CUDA is unavailable; None leaves the
    # storage location recorded in the checkpoint untouched.
    if not torch.cuda.is_available():
        map_location = torch.device('cpu')
    else:
        map_location = None

    # exist_ok avoids the check-then-create race and also creates a missing
    # parent directory.
    os.makedirs('gestop/logs', exist_ok=True)

    # Set up logger
    # NOTE(review): the timestamp puts ':' in the file name, which Windows
    # does not accept in file names -- confirm the supported platforms.
    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[
            logging.FileHandler("gestop/logs/debug{}.log".format(
                datetime.datetime.now().strftime("%m.%d:%H.%M.%S"))),
            logging.StreamHandler(stdout)
        ])

    # Disabled to prevent debug output by matplotlib
    logging.getLogger('matplotlib.font_manager').disabled = True

    # If lite is true, then the neural networks are not loaded into the config
    # This is useful in scripts which do not use the network, or may modify the network.
    lite: bool

    # Path to action configuration file
    config_path: str = 'gestop/data/action_config.json'

    # Seed value for reproducibility
    seed_val: int = 42

    # Refer make_vector() in train_model.py to verify input dimensions
    static_input_dim: int = 49
    static_output_classes: int = 7

    # Refer format_mediapipe() in dynamic_train_model.py to verify input dimensions
    dynamic_input_dim: int = 36
    dynamic_output_classes: int = 15
    shrec_output_classes: int = 14

    # Minimum number of epochs
    min_epochs: int = 15

    static_batch_size: int = 64
    dynamic_batch_size: int = 1

    pretrained: bool = True

    # value for pytorch-lighting trainer attribute accumulate_grad_batches
    grad_accum: int = 2

    # Filled from the JSON mapping files in __post_init__.
    static_gesture_mapping: dict = field(default_factory=dict)
    dynamic_gesture_mapping: dict = field(default_factory=dict)

    # Screen Resolution (queried once, when the class is created)
    resolution: Tuple = get_screen_resolution()

    # Mapping of gestures to actions, filled from `config_path` in __post_init__
    gesture_action_mapping: dict = field(default_factory=dict)

    # Locations of the saved network weights
    static_path: str = 'gestop/models/gesture_net.pth'
    shrec_path: str = 'gestop/models/shrec_net.pth'
    dynamic_path: str = 'gestop/models/user_net.pth'

    # Networks; only assigned in __post_init__ when `lite` is false
    gesture_net: GestureNet = field(init=False)
    shrec_net: ShrecNet = field(init=False)

    # Mouse tracking
    mouse: Controller = field(init=False)

    # How much a single scroll action should scroll
    scroll_unit: int = 10

    # Specifying how to map webcam coordinates to the monitor coordinates.
    # Format - [x1,y1,x2,y2] where (x1,y1) specifies which coordinate to map to
    # the top left of your screen and (x2,y2) specifies which coordinate to map
    # to the bottom right of your screen.
    # Declared as a field with a factory so that each instance owns its own
    # list (an unannotated class attribute would be one list shared by all
    # instances) and the mapping can be overridden through the constructor.
    map_coord: list = field(default_factory=lambda: [0.2, 0.2, 0.8, 0.8])

    # User configuration
    user_config: UserConfig = field(init=False)

    def __post_init__(self):
        '''
        Load the JSON mappings, create the mouse controller and user
        configuration and, unless `lite` is set, load both networks.
        '''
        # Fetching gesture mappings
        with open('gestop/data/static_gesture_mapping.json', 'r') as jsonfile:
            self.static_gesture_mapping = json.load(jsonfile)
        with open('gestop/data/dynamic_gesture_mapping.json', 'r') as jsonfile:
            self.dynamic_gesture_mapping = json.load(jsonfile)
        with open(self.config_path, 'r') as jsonfile:
            self.gesture_action_mapping = json.load(jsonfile)

        self.mouse = Controller()
        self.user_config = UserConfig()

        # Setting up networks
        if not self.lite:
            logging.info('Loading GestureNet...')
            self.gesture_net = GestureNet(self.static_input_dim, self.static_output_classes,
                                          self.static_gesture_mapping)
            self.gesture_net.load_state_dict(
                torch.load(self.static_path, map_location=self.map_location))
            # Evaluation mode: the loaded network is not being trained here.
            self.gesture_net.eval()

            logging.info('Loading ShrecNet..')
            self.shrec_net = ShrecNet(self.dynamic_input_dim, self.dynamic_output_classes,
                                      self.dynamic_gesture_mapping)
            # The weights for this network are read from `dynamic_path`.
            self.shrec_net.load_state_dict(
                torch.load(self.dynamic_path, map_location=self.map_location))
            self.shrec_net.eval()
def main():
    '''
    Train the network that recognizes dynamic hand gestures.

    Parses the optional `--exp-name` flag, reads the data, writes the gesture
    mapping to `gestop/data/dynamic_gesture_mapping.json`, builds the training
    and validation loaders, trains a `ShrecNet` with early stopping, saves its
    weights to `Config.dynamic_path` and finally runs the test loop on the
    validation loader.
    '''
    parser = argparse.ArgumentParser(
        description='A program to train a neural network '
                    'to recognize dynamic hand gestures.')
    # When no name is given on the command line the run is logged as "default".
    parser.add_argument("--exp-name", help="The name with which to log the run.",
                        type=str, default="default")
    args = parser.parse_args()

    C = Config(lite=True, pretrained=False)
    init_seed(C.seed_val)

    ##################
    # INPUT PIPELINE #
    ##################

    train_x, test_x, train_y, test_y, gesture_mapping = read_data(C.seed_val)
    with open('gestop/data/dynamic_gesture_mapping.json', 'w') as f:
        f.write(json.dumps(gesture_mapping))

    # Higher order function to pass configuration as argument
    shrec_to_mediapipe = partial(format_shrec, C)
    user_to_mediapipe = partial(format_user, C)

    # Custom transforms to prepare data.
    shrec_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(normalize),
        transforms.Lambda(resample_and_jitter),
        transforms.Lambda(shrec_to_mediapipe),
    ])
    user_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(torch.squeeze),
        transforms.Lambda(resample_and_jitter),
        transforms.Lambda(user_to_mediapipe),
    ])

    def make_loader(samples, labels):
        # Both loaders share every setting; only the data differs.
        return DataLoader(ShrecDataset(samples, labels, shrec_transform, user_transform),
                          num_workers=10, batch_size=C.dynamic_batch_size,
                          collate_fn=choose_collate(variable_length_collate, C))

    train_loader = make_loader(train_x, train_y)
    val_loader = make_loader(test_x, test_y)

    ############
    # TRAINING #
    ############

    # Use pretrained SHREC model
    if C.pretrained:
        model = ShrecNet(C.dynamic_input_dim, C.shrec_output_classes, gesture_mapping)
        # map_location mirrors the loads done in Config, so weights saved on a
        # GPU can still be read on a CPU-only machine.
        model.load_state_dict(torch.load(C.shrec_path, map_location=C.map_location))
        model.replace_layers(C.dynamic_output_classes)
    else:
        model = ShrecNet(C.dynamic_input_dim, C.dynamic_output_classes, gesture_mapping)
        model.apply(init_weights)

    early_stopping = EarlyStopping(
        patience=5,
        verbose=True,
    )

    wandb_logger = pl_loggers.WandbLogger(save_dir='gestop/logs/', name=args.exp_name,
                                          project='gestop')

    # Config leaves map_location as None only when CUDA is available; in that
    # case train on one GPU, otherwise fall back to the Trainer default (CPU)
    # instead of failing on a hard-coded gpus=1.
    gpus = 1 if C.map_location is None else None

    # NOTE(review): min_epochs is hard-coded to 20 here while Config.min_epochs
    # is 15 -- confirm which value is intended.
    trainer = Trainer(gpus=gpus,
                      deterministic=True,
                      logger=wandb_logger,
                      min_epochs=20,
                      accumulate_grad_batches=C.grad_accum,
                      early_stop_callback=early_stopping)

    trainer.fit(model, train_loader, val_loader)

    torch.save(model.state_dict(), C.dynamic_path)

    trainer.test(model, test_dataloaders=val_loader)