def test_training_only_cnn_weights(self):
    """Check which sub-modules change after training with fusion frozen.

    The runner is configured with ``train_cnn=True``, ``train_fusion=False``
    and ``train_heads=True``. A reference runner (``r_copy``) is built from
    the same seed so that it holds the initial weights. After
    ``trainer.fit(runner)``:

    * every ``fusion`` parameter must equal its initial value;
    * every ``spatial_state_head`` and ``linear_state_head`` parameter must
      differ from its initial value;
    * the first ``cnn`` parameter must differ from its initial value.

    NOTE(review): despite the test name, the heads are trained as well
    (``train_heads = True``), so this checks "CNN and heads train, fusion
    frozen" rather than "only CNN trains".
    """

    def _check_params(initial, trained, *, changed, first_only=False):
        # Compare the parameters of two same-architecture modules pairwise.
        # With ``changed=True`` every compared pair must differ; otherwise
        # every compared pair must be identical. ``first_only`` stops after
        # the first pair.
        compared = 0
        for param_copy, param in zip(initial.parameters(), trained.parameters()):
            if changed:
                assert not torch.equal(param_copy, param)
            else:
                assert torch.equal(param_copy, param)
            compared += 1
            if first_only:
                break
        if changed:
            # A module that was supposed to train must expose at least one
            # parameter; otherwise the loop above passes without checking
            # anything.
            assert compared > 0, "expected at least one parameter to compare"

    # NOTE(review): sibling tests use "." as the first path component here;
    # confirm this test file really sits one directory deeper.
    with initialize(config_path=os.path.join("..", "fixtures", "conf")):
        config = compose(
            config_name="config",
            overrides=[
                "runner/model=resnet18fusionpolicy",
                "runner/dataset=filetextbertforwardsatosapolicy",
                "runner/val_dataset=filetextbertforwardsatosapolicy",
                "runner.runner_name=SOCTextForwardPolicyRunner",
            ],
        )
        config.runner.dataset.dataset_path = _RAW_TEXT_BERT_DATASET_PATH
        config.runner.val_dataset.dataset_path = _RAW_TEXT_BERT_DATASET_PATH
        config.runner.train_cnn = True
        config.runner.train_fusion = False
        config.runner.train_heads = True
        config.trainer.default_root_dir = self.folder
        config.trainer.fast_dev_run = False

        # We rely on seeds to copy the init weights: both runners are built
        # right after seeding with the same value.
        seed_everything(config['runner']['seed'])
        r_copy = make_runner(config['runner'])
        r_copy.setup('fit')

        seed_everything(config['runner']['seed'])
        runner = make_runner(config['runner'])
        trainer = Trainer(**config['trainer'], deterministic=True)
        trainer.fit(runner)

        initial_model = r_copy.model
        trained_model = runner.model

        _check_params(initial_model.fusion, trained_model.fusion, changed=False)
        _check_params(
            initial_model.spatial_state_head,
            trained_model.spatial_state_head,
            changed=True,
        )
        _check_params(
            initial_model.linear_state_head,
            trained_model.linear_state_head,
            changed=True,
        )
        # NOTE(review): a policy_head "must change" check was commented out in
        # the original; the reason is not visible here.

        # Not all layers are learnable so we check only the first one
        _check_params(
            initial_model.cnn, trained_model.cnn, changed=True, first_only=True
        )
def test_training_soc_preprocessed_seq_convlstmpolicy(self):
    """Smoke test: fit ``SOCSeqPolicyRunner`` with the ``convlstmpolicy`` model.

    Uses the ``preprocessedseqsatosapolicy`` dataset pointed at
    ``_DATASET_PATH``. There are no assertions; the test passes if
    ``Trainer.fit`` completes without raising.
    """
    conf_dir = os.path.join(".", "fixtures", "conf")
    hydra_overrides = [
        "runner/model=convlstmpolicy",
        "runner/dataset=preprocessedseqsatosapolicy",
        "runner.runner_name=SOCSeqPolicyRunner",
    ]
    with initialize(config_path=conf_dir):
        cfg = compose(config_name="config", overrides=hydra_overrides)
        cfg.runner.dataset.dataset_path = _DATASET_PATH
        cfg.trainer.default_root_dir = self.folder

        seed_everything(cfg['runner']['seed'])
        seq_runner = make_runner(cfg['runner'])
        pl_trainer = Trainer(deterministic=True, **cfg['trainer'])
        pl_trainer.fit(seq_runner)
def test_training_soc_psql_seq_sas_conv3d(self):
    """Smoke test: fit ``SOCSupervisedSeqRunner`` with the ``conv3d`` model.

    The runner's ``setup_dataset`` is swapped for the test class's own
    ``setup_dataset`` before training. There are no assertions; the test
    passes if ``Trainer.fit`` completes without raising.
    """
    conf_dir = os.path.join(".", "fixtures", "conf")
    hydra_overrides = [
        "runner/model=conv3d",
        "runner/dataset=psqlseqsatos",
        "runner.runner_name=SOCSupervisedSeqRunner",
    ]
    with initialize(config_path=conf_dir):
        cfg = compose(config_name="config", overrides=hydra_overrides)
        cfg.trainer.default_root_dir = self.folder

        seed_everything(cfg['runner']['seed'])
        seq_runner = make_runner(cfg['runner'])
        # NOTE(review): presumably this replaces the PSQL-backed dataset setup
        # with a local fixture; confirm against self.setup_dataset.
        seq_runner.setup_dataset = self.setup_dataset
        pl_trainer = Trainer(deterministic=True, **cfg['trainer'])
        pl_trainer.fit(seq_runner)
def test_training_soc_preprocessed_forward_resnet(self):
    """Smoke test: fit ``SOCSupervisedForwardRunner`` with ``resnet18``.

    Uses the ``preprocessedforwardsatosa`` dataset pointed at
    ``_DATASET_PATH``. There are no assertions; the test passes if
    ``Trainer.fit`` completes without raising.
    """
    conf_dir = os.path.join(".", "fixtures", "conf")
    hydra_overrides = [
        "runner/model=resnet18",
        "runner/dataset=preprocessedforwardsatosa",
        "runner.runner_name=SOCSupervisedForwardRunner",
    ]
    with initialize(config_path=conf_dir):
        cfg = compose(config_name="config", overrides=hydra_overrides)
        cfg.runner.dataset.dataset_path = _DATASET_PATH
        cfg.trainer.default_root_dir = self.folder

        seed_everything(cfg['runner']['seed'])
        forward_runner = make_runner(cfg['runner'])
        pl_trainer = Trainer(deterministic=True, **cfg['trainer'])
        pl_trainer.fit(forward_runner)
def test_training_soc_file_humantrade_forward_resnetmeanffpolicy(self):
    """Smoke test: fit ``SOCTextForwardPolicyRunner`` with ``resnet18meanffpolicy``.

    Uses the ``filetextberthumantradeforwardsatosapolicy`` dataset pointed at
    ``_RAW_TEXT_BERT_DATASET_PATH``, with the runner's ``num_workers`` forced
    to 1. There are no assertions; the test passes if ``Trainer.fit``
    completes without raising.
    """
    conf_dir = os.path.join(".", "fixtures", "conf")
    hydra_overrides = [
        "runner/model=resnet18meanffpolicy",
        "runner/dataset=filetextberthumantradeforwardsatosapolicy",
        "runner.runner_name=SOCTextForwardPolicyRunner",
    ]
    with initialize(config_path=conf_dir):
        cfg = compose(config_name="config", overrides=hydra_overrides)
        cfg.runner.dataset.dataset_path = _RAW_TEXT_BERT_DATASET_PATH
        cfg.trainer.default_root_dir = self.folder

        seed_everything(cfg['runner']['seed'])
        text_runner = make_runner(cfg['runner'])
        # NOTE(review): presumably limits data-loader worker processes; confirm.
        text_runner.num_workers = 1
        pl_trainer = Trainer(deterministic=True, **cfg['trainer'])
        pl_trainer.fit(text_runner)
def test_training_soc_psql_forward_resnetmeanconcatpolicy(self):
    """Smoke test: fit ``SOCTextForwardPolicyRunner`` with ``resnet18meanconcatpolicy``.

    The runner's ``setup_dataset`` is swapped for the test class's
    ``setup_text_dataset``, and ``num_workers`` is forced to 1. There are no
    assertions; the test passes if ``Trainer.fit`` completes without raising.
    """
    conf_dir = os.path.join(".", "fixtures", "conf")
    hydra_overrides = [
        "runner/model=resnet18meanconcatpolicy",
        "runner/dataset=psqltextbertforwardsatosapolicy",
        "runner.runner_name=SOCTextForwardPolicyRunner",
    ]
    with initialize(config_path=conf_dir):
        cfg = compose(config_name="config", overrides=hydra_overrides)
        cfg.trainer.default_root_dir = self.folder

        seed_everything(cfg['runner']['seed'])
        text_runner = make_runner(cfg['runner'])
        # NOTE(review): presumably replaces the PSQL-backed dataset setup with
        # a local fixture; confirm against self.setup_text_dataset.
        text_runner.setup_dataset = self.setup_text_dataset
        text_runner.num_workers = 1
        pl_trainer = Trainer(deterministic=True, **cfg['trainer'])
        pl_trainer.fit(text_runner)
'hyper_parameters']['dataset']['use_pooler_features'] ckpt['hyper_parameters']['val_dataset']['set_empty_text_to_zero'] = ckpt[ 'hyper_parameters']['dataset']['set_empty_text_to_zero'] ckpt['hyper_parameters']['model'] = DictConfig( ckpt['hyper_parameters']['model']) ckpt['hyper_parameters']['dataset'][ 'dataset_path'] = _RAW_SOC5_TEXT_BERT_DATASET_PATH ckpt['hyper_parameters']['val_dataset'][ 'dataset_path'] = _RAW_SOC5_TEXT_BERT_DATASET_PATH ckpt['hyper_parameters']['dataset']['shuffle'] = False ckpt['hyper_parameters']['batch_size'] = 1 pl.seed_everything(ckpt['hyper_parameters']['seed']) runner = make_runner(ckpt['hyper_parameters']) runner.setup('fit') runner.load_state_dict(ckpt['state_dict']) runner.eval() pprint.pprint(runner.hparams) def format_res_tensor(res_tensor): return res_tensor.view(4, 6).detach().numpy().tolist() def format_res_pred(res_pred): return format_res_tensor(ds_utils.unnormalize_playersresources(res_pred))