Example #1
File: trainer.py  Project: sxjscience/lvt
    def __init__(self, cfg):
        """
        Args:
            cfg (CfgNode):
        """
        logger = logging.getLogger("vidgen")
        if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
            setup_logger()
        # Assume these objects must be constructed in this order.
        self.model = self.build_model(cfg)
        self.optimizers, self.checkpointers = self.model.configure_optimizers_and_checkpointers()
        self.data_loader, dataset_len = self.build_train_loader(cfg)
        self._data_loader_iter = iter(self.data_loader)
        # For training, wrap with DDP. But don't need this for inference.
        if comm.get_world_size() > 1:
            self.model.wrap_parallel(device_ids=[comm.get_local_rank()],
                                     broadcast_buffers=False)
        super().__init__(cfg)
        self.model.train()
        self.start_iter = 0
        self.max_iter = cfg.SOLVER.MAX_ITER
        self.gan_mode_on = cfg.GAN_MODE_ON
        self.supervised_max_iter = cfg.SOLVER.SUPERVISED_MAX_ITER
        self.d_update_ratio = cfg.SOLVER.D_UPDATE_RATIO
        self.d_init_iters = cfg.SOLVER.D_INIT_ITERS
        self.cfg = cfg
        self.register_hooks(self.build_hooks())
        self.accumulation_steps = cfg.SOLVER.ACCUMULATION_STEPS
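
The cfg.SOLVER.ACCUMULATION_STEPS field read at the end suggests the training loop (not shown here) accumulates gradients over several mini-batches before each optimizer step. A minimal, self-contained sketch of that pattern; the toy model and loader below are illustrative only, not this project's classes:

# Hypothetical sketch of gradient accumulation; none of these names come from the project.
import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(8)]
accumulation_steps = 4  # stands in for cfg.SOLVER.ACCUMULATION_STEPS

optimizer.zero_grad()
for step, (x, y) in enumerate(loader, start=1):
    loss = nn.functional.mse_loss(model(x), y)
    # Scale the loss so the accumulated gradient matches one large batch.
    (loss / accumulation_steps).backward()
    if step % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()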
Example #2
    def _straight_through(self, z_e_x):
        z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
        z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
        z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

        if self.ema:
            # Use EMA to update the embedding vectors
            with torch.no_grad():
                device = indices.device
                size = torch.zeros_like(self.running_size,
                                        dtype=indices.dtype,
                                        device=device)
                size.index_add_(dim=0,
                                index=indices,
                                source=torch.ones_like(indices, device=device))
                if comm.get_world_size() > 1:
                    size = AllReduce.apply(size)
                self.running_size.data.mul_(self.decay).add_(size, alpha=1 - self.decay)

                sum = torch.zeros_like(self.running_sum,
                                       dtype=z_e_x_.dtype,
                                       device=device)
                b, h, w, c = z_e_x_.size()
                sum.index_add_(dim=0,
                               index=indices,
                               source=z_e_x_.view(b * h * w, c))
                if comm.get_world_size() > 1:
                    sum = AllReduce.apply(sum)
                self.running_sum.data.mul_(sum, alpha=1 - self.decay) if False else \
                    self.running_sum.data.mul_(self.decay).add_(sum, alpha=1 - self.decay)

                n = self.running_size.sum()
                size_ = (self.running_size +
                         self.eps) / (n + self.K * self.eps) * n
                self.embedding.weight.data.copy_(self.running_sum /
                                                 size_.unsqueeze(1))

        z_q_x_bar_flatten = torch.index_select(self.embedding.weight,
                                               dim=0,
                                               index=indices)
        z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
        z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

        return z_q_x, z_q_x_bar
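
The vq_st helper comes from elsewhere in the project, but the gradient trick it relies on is the standard straight-through estimator: the forward pass returns the nearest codebook vector, while the backward pass copies gradients to the encoder output unchanged. A minimal sketch of that idea in plain PyTorch; the codebook and shapes here are illustrative:

import torch

z_e = torch.randn(2, 4, requires_grad=True)        # encoder outputs, flattened
codebook = torch.randn(8, 4)                        # K=8 embedding vectors of dim 4
indices = torch.cdist(z_e, codebook).argmin(dim=1)  # nearest codebook entry per vector
z_q = codebook[indices]                             # quantized vectors (no gradient path)
z_q_st = z_e + (z_q - z_e).detach()                 # forward: z_q, backward: identity to z_e

z_q_st.sum().backward()
print(z_e.grad)                                     # all ones: gradients pass straight through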
Example #3
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the vidgen logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info("Contents of args.config_file={}:\n{}".format(
            args.config_file,
            PathManager.open(args.config_file, "r").read()))

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)

    # cuDNN benchmark has a large overhead; it shouldn't be used given the small
    # size of a typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

        # Share tensors between dataloader workers via the file system; this avoids
        # file-descriptor exhaustion with large batch sizes and reduces data loading time.
        torch.multiprocessing.set_sharing_strategy('file_system')
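
The seed_all_rng(... cfg.SEED + rank) call gives every process a different but reproducible seed. A small illustrative helper showing the same rule; it assumes the real seed_all_rng simply seeds Python, NumPy, and PyTorch:

import random
import numpy as np
import torch

def seed_everything(base_seed: int, rank: int) -> None:
    # Different per process, but deterministic for a fixed base seed.
    seed = base_seed + rank
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

seed_everything(base_seed=42, rank=0)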
Example #4
    def __init__(self, size: int, n_samples=0):
        """
        Args:
            size (int): the total number of samples in the underlying dataset to sample from
            n_samples (int): if > 0, randomly subsample this many indices from this
                process's shard (without replacement)
        """
        self._size = size
        assert size > 0
        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        shard_size = (self._size - 1) // self._world_size + 1
        begin = shard_size * self._rank
        end = min(shard_size * (self._rank + 1), self._size)
        self._local_indices = range(begin, end)
        if n_samples > 0:
            self._local_indices = np.random.choice(self._local_indices, n_samples, replace=False)
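
The shard arithmetic above is a ceiling division of the dataset across processes, so every rank gets a contiguous slice and only the last one may be shorter. A quick standalone check with made-up numbers:

size, world_size = 10, 4
shard_size = (size - 1) // world_size + 1  # ceil(10 / 4) == 3
for rank in range(world_size):
    begin = shard_size * rank
    end = min(shard_size * (rank + 1), size)
    print(rank, list(range(begin, end)))
# rank 0: [0, 1, 2], rank 1: [3, 4, 5], rank 2: [6, 7, 8], rank 3: [9]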
Example #5
    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1],
                                  device=mean.device,
                                  dtype=mean.dtype)
                # make sure there is gradient w.r.t. input
                vec = vec + input.sum()
            else:
                vec = torch.cat([
                    mean, meansqr,
                    torch.ones([1], device=mean.device, dtype=mean.dtype)
                ], dim=0)
            vec = AllReduce.apply(vec * B)

            total_batch = vec[-1].detach()
            # no update if total_batch is 0
            momentum = total_batch.clamp(max=1) * self.momentum
            # avoid div-by-zero
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias
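
The variance is recovered from the per-channel mean and mean-of-squares as E[x^2] - E[x]^2, i.e. the biased variance that BatchNorm normalizes with. A standalone sanity check of that identity:

import torch

x = torch.randn(8, 3, 5, 5)
mean = x.mean(dim=[0, 2, 3])
meansqr = (x * x).mean(dim=[0, 2, 3])
var = meansqr - mean * mean
assert torch.allclose(var, x.var(dim=[0, 2, 3], unbiased=False), atol=1e-5)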
Example #6
    def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
        """
        Args:
            size (int): the total number of samples in the underlying dataset to sample from
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (requires synchronization among all workers).
        """
        self._size = size
        assert size > 0
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()
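
Only the constructor is shown; samplers like this typically yield an infinite stream of indices, shuffled with a generator seeded by self._seed and sharded by taking every world_size-th index starting at rank. A hedged sketch of that common pattern, not necessarily this project's exact implementation:

import itertools
import torch

def infinite_indices(size: int, seed: int, shuffle: bool = True):
    g = torch.Generator()
    g.manual_seed(seed)
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from range(size)

def sharded_indices(size: int, seed: int, rank: int, world_size: int):
    # Each rank consumes a disjoint, interleaved slice of the same shuffled stream.
    return itertools.islice(infinite_indices(size, seed), rank, None, world_size)

print(list(itertools.islice(sharded_indices(8, seed=0, rank=0, world_size=2), 6)))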
Example #7
    def __init__(self, dataset_dicts, repeat_thresh, shuffle=True, seed=None):
        """
        Args:
            dataset_dicts (list[dict]): annotations in vidgen dataset format.
            repeat_thresh (float): frequency threshold below which data is repeated.
            shuffle (bool): whether to shuffle the indices or not
            seed (int): the initial seed of the shuffle. Must be the same
                across all workers. If None, will use a random seed shared
                among workers (requires synchronization among all workers).
        """
        self._shuffle = shuffle
        if seed is None:
            seed = comm.shared_random_seed()
        self._seed = int(seed)

        self._rank = comm.get_rank()
        self._world_size = comm.get_world_size()

        # Get fractional repeat factors and split into whole number (_int_part)
        # and fractional (_frac_part) parts.
        rep_factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
        self._int_part = torch.trunc(rep_factors)
        self._frac_part = rep_factors - self._int_part
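
_get_repeat_factors is not shown; in the usual LVIS-style scheme it assigns each category a factor r(c) = max(1, sqrt(repeat_thresh / freq(c))) and each image the maximum factor over its categories, which is then split into integer and fractional parts exactly as above. A hedged sketch of that computation with made-up frequencies:

import math
import torch

# Fraction of training images containing each category (illustrative numbers).
category_freq = {"cat": 0.5, "rare_bird": 0.001}
repeat_thresh = 0.01
cat_rep = {c: max(1.0, math.sqrt(repeat_thresh / f)) for c, f in category_freq.items()}

images = [{"categories": {"cat"}}, {"categories": {"cat", "rare_bird"}}]
rep_factors = torch.tensor(
    [max(cat_rep[c] for c in img["categories"]) for img in images])
int_part = torch.trunc(rep_factors)   # whole-number repeats
frac_part = rep_factors - int_part    # applied stochastically per epoch
print(rep_factors, int_part, frac_part)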
Example #8
File: build.py  Project: sxjscience/lvt
def build_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        an infinite iterator of training data
    """
    num_workers = get_world_size()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    assert (
            images_per_batch % num_workers == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    assert (
            images_per_batch >= num_workers
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than or equal to the number of workers ({}).".format(
        images_per_batch, num_workers
    )
    images_per_worker = images_per_batch // num_workers

    dataset_dicts = get_dataset_dicts(cfg.DATASETS.TRAIN)
    dataset = DatasetFromList(dataset_dicts, copy=False)

    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger = logging.getLogger(__name__)
    logger.info("Using training sampler {}".format(sampler_name))
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = samplers.RepeatFactorTrainingSampler(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_worker, drop_last=True
    )
    # drop_last so the batches always have the same size
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )

    return data_loader, len(dataset)
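
The docstring says batching just puts the mapped dicts into a list; that is what a trivial collate function does, leaving any tensor stacking to the model. A minimal illustration of that behavior with a toy dataset (only the idea is taken from the snippet above; the names below are hypothetical):

from torch.utils.data import DataLoader, Dataset

class ToyDataset(Dataset):
    def __len__(self):
        return 4

    def __getitem__(self, i):
        return {"index": i, "label": i % 2}

def collate_as_list(batch):
    # Keep the batch as a plain list[dict] instead of stacking tensors.
    return batch

loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=collate_as_list)
print(next(iter(loader)))  # [{'index': 0, 'label': 0}, {'index': 1, 'label': 1}]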