Beispiel #1
0
    def modify_config(self, config: FaceGroupingConfig, i):
        """
        Pick the test dataset for run ``i`` and hand over to the parent class.

        Runs with ``i`` below 6 use ``BigBangTheory`` with test episode ``i``;
        every other run uses ``Buffy`` with test episode ``i - 6``.

        :param config: configuration whose ``dataset`` attribute is replaced.
        :param i: index of the current run.
        :return: whatever the parent ``modify_config`` returns.
        """
        num_bbt_runs = 6
        config.dataset = (
            BigBangTheory(episode_index_test=i)
            if i < num_bbt_runs
            else Buffy(episode_index_test=i - num_bbt_runs)
        )
        return super().modify_config(config, i)
Beispiel #2
0
 def __init__(self):
     """
     Example experiment: inference on BBT0101 without assuming known labels.

     Use this class as a template when running on a new dataset.
     """
     first_episode = 0
     self.dataset = BigBangTheory(episode_index_test=first_episode)
     # The checkpoint path is a literal ``...`` placeholder in this example.
     super().__init__(checkpoint_file=...)
Beispiel #3
0
 def modify_config(self, config: FaceGroupingConfig, i):
     """
     Configure run ``i`` with a ``SplitEveryXFrames(x=10)`` split strategy.

     Also sets ``pool_before_clustering`` to True and replaces the dataset by
     ``BigBangTheory`` with test episode ``i`` and no train/validation episode.

     :param config: configuration object that is modified in place.
     :param i: index of the current run, used as the test episode index.
     :return: tuple starting with the frame interval as a string, followed by
         the items returned by the parent ``modify_config``.
     """
     from FGG.dataset.split_strategy import SplitEveryXFrames

     frame_interval = 10
     splitter = SplitEveryXFrames(x=frame_interval)
     config.graph_builder_params["split_strategy"] = splitter
     config.pool_before_clustering = True

     episodes = dict(episode_index_val=None,
                     episode_index_train=None,
                     episode_index_test=i)
     config.dataset = BigBangTheory(**episodes)

     inherited = super().modify_config(config, i)
     return (str(frame_interval), *inherited)
Beispiel #4
0
 def __init__(self):
     """
     Set up 10 runs in total: 5 training runs on BF0502 and 5 on BBT0101.

     Both datasets use the same episode for training and testing and have
     no validation episode.
     """
     bf_episode = 1
     bbt_episode = 0
     self.bf = Buffy(
         episode_index_train=bf_episode,
         episode_index_val=None,
         episode_index_test=bf_episode,
     )
     self.bbt = BigBangTheory(
         episode_index_train=bbt_episode,
         episode_index_val=None,
         episode_index_test=bbt_episode,
     )
     super().__init__(header=(), num_runs=10)
Beispiel #5
0
 def modify_config(self, config, i):
     """
     Train and test run ``i`` on BigBangTheory episode ``i // 5``.

     The dataset is created without a validation episode; the train and test
     episode indices are then both set to ``i // 5``.

     :param config: configuration whose ``dataset`` attribute is replaced.
     :param i: index of the current run.
     :return: whatever the parent ``modify_config`` returns.
     """
     episode = i // 5
     config.dataset = BigBangTheory(episode_index_val=None)
     for attribute in ("episode_index_train", "episode_index_test"):
         setattr(config.dataset, attribute, episode)
     return super().modify_config(config, i)
Beispiel #6
0
 def __init__(self):
     """
     Create a default ``BigBangTheory`` dataset and request 5 runs.

     The parent initialiser receives an empty header tuple.
     """
     self.bbt = BigBangTheory()
     run_settings = {"header": (), "num_runs": 5}
     super().__init__(**run_settings)
Beispiel #7
0
    def __init__(self):
        """
        This is the central configuration file for FGG.
        You can change most settings here without touching the code.
        All data that is stored in this class is automatically serialized so you know what you did.

        The parameters here can be set via command line flags that will be read automatically.
        The flags' names correspond to the attributes defined here.

        The attributes are assigned inside ``self.argument_group(...)`` blocks, one
        block per topic (model loading/storing, task data, execution, model
        parameters, runtime duration, optimizer parameters).
        """
        # NOTE(review): `timestamp`, `git_hexsha` and `argument_group` are used below but
        # not defined in this method; presumably the base class provides them - confirm.
        super().__init__()

        with self.argument_group("model-load-store"):
            # Default name: the timestamp plus the first five characters of the git hash.
            self.model_name = f"{self.timestamp}_{self.git_hexsha[:5]}"
            # All file/folder locations default to None here.
            # NOTE(review): presumably filled in elsewhere or via flags - confirm.
            self.output_folder = None
            self.statistics_file = None
            self.replay_log_file = None
            self.run_info_file = None
            self.model_save_file = None
            self.model_load_file = None
            self.store_features_file = None
            self.strict = True
            self.performance_indicator = "test-weighted clustering purity"

        with self.argument_group("task-data"):
            self.include_unknown = False
            # The dataset receives the value of `include_unknown` at construction time.
            # NOTE(review): confirm whether overriding `include_unknown` later also
            # requires rebuilding the dataset.
            self.dataset = BigBangTheory(episode_index_train=0,
                                         include_unknown=self.include_unknown)
            self.wcp_version = cluster_mode_prediction

        with self.argument_group("execution"):
            self.device = "cuda"
            self.seed = None
            # "runs" folder located two directory levels above this file.
            self.store_base_path = Path(__file__).parent.parent / "runs"

        with self.argument_group("model-parameters"):
            # The model, loss and graph-builder settings are stored as a type plus a
            # dict of parameters; how they are consumed is not visible in this method.
            self.model_type = FGG
            self.model_params = dict(in_feature_size=2048,
                                     downsample_feature_size=1024,
                                     gc_feature_sizes=(512, 256, 128),
                                     use_residual=True,
                                     activation=F.elu,
                                     num_edge_types=2,
                                     sparse_adjacency=False)
            self.loss_type = GraphContrastiveLoss
            self.loss_params = dict(margin=1.)
            self.graph_builder_params = dict(
                pos_edge_dropout=None,
                neg_edge_dropout=None,
                split_strategy=HeuristicSplit(),
                pair_sample_fraction=1,
                edge_between_top_fraction=0.03,
                isolates_similarity_only=True,
                weighted_edges=True,
            )

        with self.argument_group("runtime-duration"):
            self.train_epochs = 30
            self.split_disconnected_components = False

        with self.argument_group("optimizer-parameters"):
            # Adam with learning rate 1e-4 and no weight decay by default.
            self.optimizer_type = torch.optim.Adam
            self.optimizer_params = dict(lr=1e-4, weight_decay=0)