Example no. 1
0
    def resolve_training_type_plugin(
            self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
        """Fill in unset attributes of a user-provided training type plugin.

        Attributes the plugin exposes but left as ``None``
        (``parallel_devices``, ``cluster_environment``, ``num_nodes``,
        ``sync_batchnorm``) are populated from this connector's settings.

        Args:
            training_type: the plugin instance passed in by the user.

        Returns:
            The same plugin instance, with missing attributes filled in.
        """
        # necessary for when the user has passed in a plugin
        # Use `is None` rather than a truthiness test so a deliberately
        # empty `parallel_devices` list supplied by the user is preserved.
        if hasattr(training_type, 'parallel_devices') and getattr(
                training_type, 'parallel_devices') is None:
            training_type.parallel_devices = self.parallel_devices
            if hasattr(training_type, 'num_processes'):
                training_type.num_processes = len(self.parallel_devices)

        if hasattr(training_type, 'cluster_environment') and getattr(
                training_type, 'cluster_environment') is None:
            training_type.cluster_environment = self.select_cluster_environment(
            )

        if hasattr(
                training_type,
                'num_nodes') and getattr(training_type, 'num_nodes') is None:
            training_type.num_nodes = self.num_nodes

        # Automatically set sync_batchnorm if None.
        # Useful for custom plugins.
        if hasattr(training_type, 'sync_batchnorm') and getattr(
                training_type, 'sync_batchnorm') is None:
            training_type.sync_batchnorm = self.sync_batchnorm

        return training_type
    def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
        """Populate unset attributes of a user-supplied training type plugin.

        Args:
            training_type: the plugin instance passed in by the user.

        Returns:
            The same plugin, with ``parallel_devices``, ``cluster_environment``
            and ``num_nodes`` filled in from this connector where they were
            left as ``None``.
        """
        # necessary for when the user has passed in a plugin
        # Use `is None` rather than a truthiness test so a deliberately
        # empty `parallel_devices` list supplied by the user is preserved.
        if hasattr(training_type, 'parallel_devices') and getattr(training_type, 'parallel_devices') is None:
            training_type.parallel_devices = self.parallel_devices
            if hasattr(training_type, 'num_processes'):
                training_type.num_processes = len(self.parallel_devices)

        if hasattr(training_type, 'cluster_environment') and getattr(training_type, 'cluster_environment') is None:
            training_type.cluster_environment = self.select_cluster_environment()

        if hasattr(training_type, 'num_nodes') and getattr(training_type, 'num_nodes') is None:
            training_type.num_nodes = self.num_nodes

        return training_type
Example no. 3
0
    def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
        """Complete a user-provided training type plugin with trainer settings.

        ``parallel_devices`` and ``cluster_environment`` are supplied only when
        the plugin exposes them and they are ``None``; ``num_nodes`` and
        ``sync_batchnorm`` always mirror the trainer setting when exposed.
        """
        # Sentinel distinguishes "attribute absent" from "attribute is None".
        _missing = object()

        # The user passed in a plugin directly; give it devices if unset.
        if getattr(training_type, 'parallel_devices', _missing) is None:
            training_type.parallel_devices = self.parallel_devices
            if hasattr(training_type, 'num_processes'):
                training_type.num_processes = len(self.parallel_devices)

        if getattr(training_type, 'cluster_environment', _missing) is None:
            training_type.cluster_environment = self.select_cluster_environment()

        if hasattr(training_type, 'num_nodes'):
            # set num_nodes for training_type from trainer setting
            training_type.num_nodes = self.num_nodes

        if hasattr(training_type, 'sync_batchnorm'):
            # set sync_batchnorm for training_type from trainer setting
            training_type.sync_batchnorm = self.sync_batchnorm

        return training_type
    def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
        """Complete a user-provided training type plugin with connector settings.

        ``parallel_devices`` and ``cluster_environment`` are supplied only when
        the plugin exposes them and they are ``None``; ``num_nodes`` and
        ``sync_batchnorm`` always mirror the trainer setting when exposed.
        """
        # Sentinel distinguishes "attribute absent" from "attribute is None".
        sentinel = object()

        # The user passed in a plugin directly; give it devices if unset.
        if getattr(training_type, "parallel_devices", sentinel) is None:
            training_type.parallel_devices = self.parallel_devices
            if hasattr(training_type, "num_processes"):
                training_type.num_processes = len(self.parallel_devices)

        if getattr(training_type, "cluster_environment", sentinel) is None:
            # transfer ownership of the cluster environment to the training
            # type; the connector keeps only a weak proxy to it
            training_type.cluster_environment = self.cluster_environment
            self._cluster_environment = proxy(self.cluster_environment)

        if hasattr(training_type, "num_nodes"):
            # set num_nodes for training_type from trainer setting
            training_type.num_nodes = self.num_nodes

        if hasattr(training_type, "sync_batchnorm"):
            # set sync_batchnorm for training_type from trainer setting
            training_type.sync_batchnorm = self.sync_batchnorm

        return training_type