Example #1
    def init_distributed_data_parallel_model(self):
        """
        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not set up correctly"

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to re-work the checkpoint logic because we don't
        # have enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
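For orientation, here is a minimal sketch of what this outer wrap boils down to, assuming fairscale's FullyShardedDataParallel and a config dict restricted to the keys that appear later in these examples (flatten_parameters, mixed_precision, fp32_reduce_scatter). The exact contents of MODEL.FSDP_CONFIG are an assumption here, and a distributed process group must already be initialized:

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

def wrap_base_model(base_model: nn.Module) -> FSDP:
    # Hypothetical values; VISSL reads this dict from config["MODEL"]["FSDP_CONFIG"].
    fsdp_config = {
        "flatten_parameters": True,
        "mixed_precision": True,
        "fp32_reduce_scatter": True,
    }
    # Wrap the whole model at the outermost level; assumes torch.distributed
    # is already initialized (typically NCCL, one GPU per rank).
    return FSDP(module=base_model, **fsdp_config)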
Example #2
    def __init__(self, loss_config: AttrDict):
        super().__init__()
        self.loss_config = loss_config
        self.momentum_teacher = None
        self.checkpoint = None
        self.teacher_output = None
        self.teacher_temp = None
        self.is_distributed = is_distributed_training_run()
        self.use_gpu = get_cuda_device_index() > -1
        self.center = None
Example #3
    def __init__(self, buffer_params, temperature: float):
        super(SimclrInfoNCECriterion, self).__init__()

        self.use_gpu = get_cuda_device_index() > -1
        self.temperature = temperature
        self.num_pos = 2
        self.buffer_params = buffer_params
        self.criterion = nn.CrossEntropyLoss()
        self.dist_rank = get_rank()
        self.pos_mask = None
        self.neg_mask = None
        self.precompute_pos_neg_mask()
        logging.info(f"Creating Info-NCE loss on Rank: {self.dist_rank}")
Example #4
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure default cuda device is set. TODO (Min): we should ensure FSDP can
        # be enabled for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not set up correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototypes{i} layers if it is SwAV.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fsdp_config)
                setattr(head0, "prototypes" + str(j), module)

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to re-work the checkpoint logic because we don't
        # have enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
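The comment about regnet_fsdp.py refers to trunks whose inner blocks are already FSDP instances before this method runs. A toy sketch of that nesting idea, with hypothetical layer sizes and assuming a process group is already initialized (the outermost wrap shards whatever parameters the inner wraps left untouched):

import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP

def build_nested_model(fsdp_config: dict) -> FSDP:
    inner = FSDP(nn.Linear(16, 16), **fsdp_config)   # already wrapped "internally"
    trunk = nn.Sequential(inner, nn.Linear(16, 8))   # this layer is left plain
    return FSDP(module=trunk, **fsdp_config)         # outermost wrap picks it up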
Example #5
    def __init__(self, loss_config: AttrDict):
        super().__init__()
        self.loss_config = loss_config

        self.momentum_encoder = None
        self.checkpoint = None
        self.momentum_scores = None
        self.momentum_embeddings = None
        self.is_distributed = is_distributed_training_run()
        self.use_gpu = get_cuda_device_index() > -1
        self.softmax = nn.Softmax(dim=1)

        # keep track of number of iterations
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))

        # for queue
        self.use_queue = False
        if self.loss_config.queue.local_queue_length > 0:
            self.initialize_queue()
Example #6
    def __init__(
        self,
        temperature: float,
        crops_for_assign: List[int],
        num_crops: int,
        num_iters: int,
        epsilon: float,
        use_double_prec: bool,
        num_prototypes: List[int],
        local_queue_length: int,
        embedding_dim: int,
        temp_hard_assignment_iters: int,
        output_dir: str,
    ):
        super(SwAVCriterion, self).__init__()

        self.use_gpu = get_cuda_device_index() > -1

        self.temperature = temperature
        self.crops_for_assign = crops_for_assign
        self.num_crops = num_crops
        self.nmb_sinkhornknopp_iters = num_iters
        self.epsilon = epsilon
        self.use_double_prec = use_double_prec
        self.num_prototypes = num_prototypes
        self.nmb_heads = len(self.num_prototypes)
        self.embedding_dim = embedding_dim
        self.temp_hard_assignment_iters = temp_hard_assignment_iters
        self.local_queue_length = local_queue_length
        self.dist_rank = get_rank()
        self.world_size = get_world_size()
        self.log_softmax = nn.LogSoftmax(dim=1).cuda()
        self.softmax = nn.Softmax(dim=1).cuda()
        self.register_buffer("num_iteration", torch.zeros(1, dtype=int))
        self.use_queue = False
        if local_queue_length > 0:
            self.initialize_queue()
        self.output_dir = output_dir
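initialize_queue() itself is not shown above. A minimal sketch of how such a local queue could be set up so it is saved and restored with state_dict() (the buffer name and the (nmb_heads, local_queue_length, embedding_dim) shape are illustrative assumptions, not VISSL's actual layout):

import torch

def initialize_queue(self):
    # One queue per prototype head, each holding local_queue_length embeddings;
    # registering it as a buffer makes it part of state_dict()/checkpoints.
    self.register_buffer(
        "local_queue",
        torch.randn(self.nmb_heads, self.local_queue_length, self.embedding_dim),
    )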
Example #7
    def init_distributed_data_parallel_model(self):
        """
        Initialize FSDP if needed.

        This method overrides the ClassificationTask class's method from ClassyVision.
        """
        if not is_distributed_training_run():
            return

        # Make sure default cuda device is set. TODO (Min): we should ensure FSDP can
        # be enabled for 1-GPU as well, but the use case there is likely different.
        # I.e. perhaps we use it for cpu_offloading.
        assert (
            get_cuda_device_index() > -1
        ), "Distributed training not set up correctly"

        # The model might already be wrapped by FSDP internally. Check regnet_fsdp.py.
        # Here, we wrap it at the outermost level.
        fsdp_config = self.config["MODEL"]["FSDP_CONFIG"]
        if is_primary():
            logging.info(f"Using FSDP, config: {fsdp_config}")

        # First, wrap the head's prototypes{i} layers if it is SwAV.
        # TODO (Min): make this more general for different models, which may have multiple
        #             heads.
        if len(self.base_model.heads) != 1:
            raise ValueError(
                f"FSDP only supports 1 head, not {len(self.base_model.heads)} heads"
            )
        head0 = self.base_model.heads[0]
        if isinstance(head0, SwAVPrototypesHead):
            # This is important for convergence!
            #
            # Since we "normalize" this layer in the update hook, we need to keep its
            # weights in full precision. It is output is going into the loss and used
            # for clustering, so we need to have that in full precision as well.
            fp_fsdp_config = fsdp_config.copy()
            fp_fsdp_config["flatten_parameters"] = False
            fp_fsdp_config["mixed_precision"] = False
            fp_fsdp_config["fp32_reduce_scatter"] = False
            for j in range(head0.nmb_heads):
                module = getattr(head0, "prototypes" + str(j))
                module = FSDP(module=module, **fp_fsdp_config)
                setattr(head0, "prototypes" + str(j), module)
        head0 = FSDP(module=head0, **fsdp_config)
        self.base_model.heads[0] = head0

        # Init the head properly since the weights are potentially initialized on different
        # ranks with different seeds. We first summon the full params from all workers.
        # Then, within that context, we set a fixed random seed so that all workers init the
        # weights the same way. Finally, we reset the layer's weights using reset_parameters().
        #
        # TODO (Min): This will go away once we have a way to sync from rank 0.
        with head0.summon_full_params():
            with set_torch_seed(self.config["SEED_VALUE"]):
                for m in head0.modules():
                    if isinstance(m, Linear):
                        m.reset_parameters()
        head0._reset_lazy_init()
        head0.prototypes0._reset_lazy_init()

        # TODO (Min): We can load a checkpoint, but it ends up setting the trunk's _is_root
        # flag to True. We need to set it back to None here.
        # Also, right now, the head's weights are only partially loaded from the checkpoint
        # because we dump the checkpoint after the head is wrapped, but load it before
        # it is wrapped.
        # For very big models, we need to re-work the checkpoint logic because we don't
        # have enough memory to load the entire model on one node. We need to use the
        # local_state_dict() API to load checkpoint shards.
        for module in self.base_model.trunk.modules():
            if isinstance(module, FSDP):
                module._is_root = None

        # Then, wrap the whole model. We replace the base_model since it is used
        # when a checkpoint is taken.
        self.base_model = FSDP(module=self.base_model, **fsdp_config)
        self.distributed_model = self.base_model
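The summon_full_params() block is what makes the randomly initialized head identical on every rank. In isolation, the pattern looks roughly like this (a sketch assuming fairscale's FSDP and an already-initialized process group; VISSL's set_torch_seed() context manager is approximated here with a plain torch.manual_seed call):

import torch
from torch.nn import Linear
from fairscale.nn import FullyShardedDataParallel as FSDP

def sync_head_init(head: FSDP, seed: int) -> None:
    # Gather the full (unsharded) parameters on every rank, re-initialize them
    # under the same seed, then let FSDP lazily re-shard on the next forward.
    with head.summon_full_params():
        torch.manual_seed(seed)  # stand-in for set_torch_seed(seed)
        for m in head.modules():
            if isinstance(m, Linear):
                m.reset_parameters()
    head._reset_lazy_init()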