Example #1
import logging

# Assumed module-level state and helpers for this snippet:
# `logger_initialized` records which loggers have already been set up, and
# `is_distributed()` / `get_rank()` are expected to come from the surrounding
# package's distributed utilities.
logger_initialized = {}


def get_logger(name, log_file=None, log_level=logging.INFO):
    """Initialize and get a logger by name.

    If the logger has not been initialized, this function initializes it by
    adding one or two handlers; otherwise the already-initialized logger is
    returned directly. During initialization, a StreamHandler is always
    added. If `log_file` is specified and the process rank is 0, a
    FileHandler is also added.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the rank 0
            process is affected; other processes set the level to
            ``logging.ERROR`` and are therefore silent most of the time.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)

    if name in logger_initialized:
        return logger

    # handle hierarchical names
    # e.g., logger "a" is initialized, then logger "a.b" will skip the
    # initialization since it is a child of "a".
    for logger_name in logger_initialized:
        if name.startswith(logger_name):  # child
            return logger

    # workaround for duplicated stream output, kept disabled for reference:
    # while logger.handlers:
    #     logger.handlers.pop()

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]

    if is_distributed():
        rank = get_rank()
    else:
        rank = 0

    # only rank 0 will add a FileHandler
    if rank == 0 and log_file is not None:
        file_handler = logging.FileHandler(log_file, 'w')
        handlers.append(file_handler)

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    if rank == 0:
        logger.setLevel(log_level)
    else:
        logger.setLevel(logging.ERROR)

    logger_initialized[name] = True
    return logger
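
A minimal usage sketch (the logger name and log file below are illustrative, not from the original):

# Hypothetical usage; on rank 0 this logs to the stream and to train.log.
logger = get_logger('myproject', log_file='train.log')
logger.info('training started')

# A child logger reuses the parent's handlers; no new handlers are added.
child = get_logger('myproject.data')
child.warning('falling back to single-worker loading')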
Example #2
    # Assumed module-level imports for this snippet: `math`, `numpy as np`,
    # and a distributed backend bound to `dist` (e.g. megengine.distributed).
    def __init__(
        self,
        dataset,
        batch_size=1,
        drop_last=False,
        num_samples=None,
        world_size=None,
        rank=None,
        seed=None,
    ):
        r"""
        An abstract base class for all samplers.

        :type dataset: `dataset`
        :param dataset: dataset to sample from.
        :type batch_size: positive integer
        :param batch_size: batch size for batch method.
        :type drop_last: bool
        :param drop_last: set ``True`` to drop the last incomplete batch
            if the dataset size is not divisible by the batch size. If
            ``False``, the last batch will simply be smaller. Default: False
        :type num_samples: positive integer
        :param num_samples: number of samples assigned to one rank.
        :type world_size: positive integer
        :param world_size: number of ranks.
        :type rank: non-negative integer
        :param rank: rank id, an integer in ``[0, world_size)``.
        :type seed: non-negative integer
        :param seed: seed for random operators.
        """
        if (not isinstance(batch_size, int) or isinstance(batch_size, bool)
                or batch_size <= 0):
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        if num_samples is not None and (not isinstance(num_samples, int)
                                        or isinstance(num_samples, bool)
                                        or num_samples <= 0):
            raise ValueError(
                "num_samples should be a positive integer "
                "value, but got num_samples={}".format(num_samples))

        self.batch_size = batch_size
        self.dataset = dataset
        self.drop_last = drop_last

        if world_size is None:
            world_size = dist.get_world_size() if dist.is_distributed() else 1
        self.world_size = world_size
        if rank is None:
            rank = dist.get_rank() if dist.is_distributed() else 0
        self.rank = rank

        if num_samples is None:
            num_samples = len(self.dataset)
        self.num_samples = int(math.ceil(num_samples / self.world_size))

        # Make sure seeds are the same at each rank
        if seed is None and self.world_size > 1:
            seed = 0
        self.rng = np.random.RandomState(seed)
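
A hedged sketch of a concrete subclass (the base-class name ``Sampler``, the subclass name, and the ``__iter__`` logic are assumptions for illustration, not part of the original API):

# Hypothetical subclass; assumes the __init__ above belongs to `Sampler`.
class SequentialBatchSampler(Sampler):
    def __iter__(self):
        # each rank takes a contiguous shard of `num_samples` indices
        start = self.rank * self.num_samples
        shard = range(start, min(start + self.num_samples, len(self.dataset)))
        batch = []
        for idx in shard:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # emit the trailing partial batch unless drop_last is set
        if batch and not self.drop_last:
            yield batch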
Example #3
# Assumed imports/globals for this snippet: `numpy as np`, `megengine as mge`,
# `megengine.functional as F`, `megengine.distributed as dist`, and a
# module-level `hidden_channels` integer defined elsewhere.
def train_generator_batch(image, label, *, opt, netG, netloss):
    netG.train()
    B, T, _, H, W = image.shape
    # decompose the LR input into a structure component (image_S, obtained by
    # 4x down- then up-sampling) and a detail residual (image_D)
    image_S = image.reshape((B * T, -1, H, W))
    image_S = F.interpolate(image_S, scale_factor=[0.25, 0.25])
    image_S = F.interpolate(image_S, size=[H, W])
    image_S = image_S.reshape((B, T, -1, H, W))
    image_D = image - image_S
    # decompose the HR label (4x the LR resolution) the same way
    label_S = label.reshape((B * T, -1, 4 * H, 4 * W))
    label_S = F.interpolate(label_S, scale_factor=[0.25, 0.25])
    label_S = F.interpolate(label_S, size=[4 * H, 4 * W])
    label_S = label_S.reshape((B, T, -1, 4 * H, 4 * W))
    label_D = label - label_S

    HR_G = []
    HR_D = []
    HR_S = []

    # first frame: build a 5-frame window by reflection padding ([2, 1, 0, 1, 2])
    pre_S_hat = mge.tensor(
        np.zeros((B, hidden_channels, H, W), dtype=np.float32))
    pre_D_hat = F.zeros_like(pre_S_hat)
    pre_SD = F.zeros_like(pre_S_hat)
    LR = F.concat([
        F.add_axis(image[:, 2, ...], axis=1),
        F.add_axis(image[:, 1, ...], axis=1), image[:, 0:3, ...]
    ],
                  axis=1)
    LR_S = F.concat([
        F.add_axis(image_S[:, 2, ...], axis=1),
        F.add_axis(image_S[:, 1, ...], axis=1), image_S[:, 0:3, ...]
    ],
                    axis=1)
    LR_D = F.concat([
        F.add_axis(image_D[:, 2, ...], axis=1),
        F.add_axis(image_D[:, 1, ...], axis=1), image_D[:, 0:3, ...]
    ],
                    axis=1)
    imgHR, pre_SD, pre_S_hat, pre_D_hat, img_S, img_D = netG(
        LR, LR_S, LR_D, pre_S_hat, pre_D_hat, pre_SD)
    # first frame result
    HR_G.append(F.add_axis(imgHR, axis=1))
    HR_D.append(F.add_axis(img_D, axis=1))
    HR_S.append(F.add_axis(img_S, axis=1))

    # second frame: reflected window [1, 0, 1, 2, 3]
    LR = F.concat([F.add_axis(image[:, 1, ...], axis=1), image[:, 0:4, ...]],
                  axis=1)
    LR_S = F.concat(
        [F.add_axis(image_S[:, 1, ...], axis=1), image_S[:, 0:4, ...]], axis=1)
    LR_D = F.concat(
        [F.add_axis(image_D[:, 1, ...], axis=1), image_D[:, 0:4, ...]], axis=1)
    imgHR, pre_SD, pre_S_hat, pre_D_hat, img_S, img_D = netG(
        LR, LR_S, LR_D, pre_S_hat, pre_D_hat, pre_SD)
    # second frame result
    HR_G.append(F.add_axis(imgHR, axis=1))
    HR_D.append(F.add_axis(img_D, axis=1))
    HR_S.append(F.add_axis(img_S, axis=1))

    # middle frames: a full 5-frame window [t-2, t+2] is available
    for t in range(2, T - 2):
        imgHR, pre_SD, pre_S_hat, pre_D_hat, img_S, img_D = netG(
            image[:, t - 2:t + 3, ...], image_S[:, t - 2:t + 3, ...],
            image_D[:, t - 2:t + 3, ...], pre_S_hat, pre_D_hat, pre_SD)
        HR_G.append(F.add_axis(imgHR, axis=1))
        HR_D.append(F.add_axis(img_D, axis=1))
        HR_S.append(F.add_axis(img_S, axis=1))

    # frame T-2: reflected window [T-4, T-3, T-2, T-1, T-2]
    LR = F.concat(
        [image[:, T - 4:T, ...],
         F.add_axis(image[:, -2, ...], axis=1)],
        axis=1)
    LR_S = F.concat(
        [image_S[:, T - 4:T, ...],
         F.add_axis(image_S[:, -2, ...], axis=1)],
        axis=1)
    LR_D = F.concat(
        [image_D[:, T - 4:T, ...],
         F.add_axis(image_D[:, -2, ...], axis=1)],
        axis=1)
    imgHR, pre_SD, pre_S_hat, pre_D_hat, img_S, img_D = netG(
        LR, LR_S, LR_D, pre_S_hat, pre_D_hat, pre_SD)
    # T-2 frame result
    HR_G.append(F.add_axis(imgHR, axis=1))
    HR_D.append(F.add_axis(img_D, axis=1))
    HR_S.append(F.add_axis(img_S, axis=1))

    # frame T-1 (last): reflected window [T-3, T-2, T-1, T-2, T-3]
    LR = F.concat([
        image[:, T - 3:T, ...],
        F.add_axis(image[:, -2, ...], axis=1),
        F.add_axis(image[:, -3, ...], axis=1)
    ],
                  axis=1)
    LR_S = F.concat([
        image_S[:, T - 3:T, ...],
        F.add_axis(image_S[:, -2, ...], axis=1),
        F.add_axis(image_S[:, -3, ...], axis=1)
    ],
                    axis=1)
    LR_D = F.concat([
        image_D[:, T - 3:T, ...],
        F.add_axis(image_D[:, -2, ...], axis=1),
        F.add_axis(image_D[:, -3, ...], axis=1)
    ],
                    axis=1)
    imgHR, pre_SD, pre_S_hat, pre_D_hat, img_S, img_D = netG(
        LR, LR_S, LR_D, pre_S_hat, pre_D_hat, pre_SD)
    # T-1 frame result
    HR_G.append(F.add_axis(imgHR, axis=1))
    HR_D.append(F.add_axis(img_D, axis=1))
    HR_S.append(F.add_axis(img_S, axis=1))

    HR_G = F.concat(HR_G, axis=1)
    HR_D = F.concat(HR_D, axis=1)
    HR_S = F.concat(HR_S, axis=1)
    # assert HR_G.shape == HR_D.shape and HR_D.shape == HR_S.shape # [B,T,C,H,W]
    loss = netloss(HR_G, HR_D, HR_S, label, label_D, label_S)
    opt.backward(loss)
    if dist.is_distributed():
        # average the loss across ranks (all-reduce mean); the original body
        # was left as `pass`, so this line is an assumption following the
        # pattern used in MegEngine's official training examples
        loss = dist.functional.all_reduce_sum(loss) / dist.get_world_size()
    return loss
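
A hedged driver sketch (the dataloader, tensor shapes, and the old-style MegEngine optimizer calls are assumptions; note that `opt.backward(loss)` is already called inside the function above):

# Hypothetical training loop; names and shapes are illustrative.
for step, (image, label) in enumerate(dataloader):
    # image: [B, T, C, H, W] LR clip; label: [B, T, C, 4H, 4W] HR clip
    image, label = mge.tensor(image), mge.tensor(label)
    opt.zero_grad()
    loss = train_generator_batch(image, label, opt=opt, netG=netG, netloss=netloss)
    opt.step()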