Exemple #1
0
def wflw_data_iterator(data_dir=None,
                       dataset_mode="encoder_ref",
                       mode="train",
                       use_reference=False,
                       batch_size=1,
                       shuffle=True,
                       rng=None,
                       with_memory_cache=False,
                       with_file_cache=False,
                       transform=None):
    """Create a DataIterator over the WFLW dataset for the encoder.

    Args:
        data_dir: Dataset location forwarded to the data source.
        dataset_mode: Accepted for interface compatibility; not used here.
        mode: Mode string forwarded to the data source (e.g. ``"train"``).
        use_reference: If True, read from ``WFLWDataEncoderRefSource``
            (reference ``.npz`` file); otherwise ``WFLWDataEncoderSource``.
        batch_size: Mini-batch size.
        shuffle: Whether the data source shuffles samples.
        rng: Random state forwarded to the data source and ``data_iterator``.
        with_memory_cache: Forwarded to ``data_iterator``.
        with_file_cache: Forwarded to ``data_iterator``.
        transform: Transform forwarded to the data source.

    Returns:
        The iterator created by ``data_iterator``.
    """
    if use_reference:
        logger.info(
            'WFLW Dataset for Encoder using reference .npz file is created.')
        source_cls = WFLWDataEncoderRefSource
    else:
        logger.info('WFLW Dataset for Encoder is created.')
        source_cls = WFLWDataEncoderSource

    source = source_cls(data_dir,
                        shuffle=shuffle,
                        rng=rng,
                        transform=transform,
                        mode=mode)
    # Pass flags by keyword: data_iterator also accepts ``use_thread``, so
    # positional booleans can silently bind to the wrong parameter.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
def celebv_data_iterator(dataset_mode=None, celeb_name=None, data_dir=None, ref_dir=None,
                         mode="all", batch_size=1, shuffle=False, rng=None,
                         with_memory_cache=False, with_file_cache=False,
                         resize_size=(64, 64), line_thickness=3, gaussian_kernel=(5, 5), gaussian_sigma=3
                         ):
    """Create a CelebV DataIterator for the Transformer or Decoder stage.

    Args:
        dataset_mode: ``'transformer'`` (heatmaps only) or ``'decoder'``
            (images, heatmaps and resized heatmaps).
        celeb_name, data_dir, mode, shuffle, rng: Forwarded to the data source.
        ref_dir: If given, read from a reference ``.npz`` via
            ``CelebVDataRefSource``; the path must exist.
        batch_size: Mini-batch size.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.
        resize_size, line_thickness, gaussian_kernel, gaussian_sigma:
            Forwarded to ``CelebVDataSource`` only (not used with ``ref_dir``).

    Returns:
        The iterator created by ``data_iterator``.

    Raises:
        AssertionError: if ``ref_dir`` is given but does not exist.
        SystemExit: (status 1) if ``dataset_mode`` is not recognised.
    """
    # Per mode: (label used in log messages, need_image, need_resized_heatmap).
    mode_settings = {
        'transformer': ('Transformer', False, False),
        'decoder': ('Decoder', True, True),
    }
    if dataset_mode not in mode_settings:
        logger.error(
            'Specified Dataitaretor is wrong?  given: {}'.format(dataset_mode))
        import sys
        # Non-zero status: a bare sys.exit() would report success (0).
        sys.exit(1)

    label, need_image, need_resized_heatmap = mode_settings[dataset_mode]
    common_kwargs = dict(celeb_name=celeb_name, data_dir=data_dir,
                         need_image=need_image, need_heatmap=True,
                         need_resized_heatmap=need_resized_heatmap,
                         mode=mode, shuffle=shuffle, rng=rng)

    if ref_dir:
        assert os.path.exists(ref_dir), f'{ref_dir} not found.'
        logger.info(
            f'CelebV Dataiterator using reference .npz file for {label} is created.')
        source = CelebVDataRefSource(ref_dir=ref_dir, **common_kwargs)
    else:
        logger.info(f'CelebV Dataiterator for {label} is created.')
        source = CelebVDataSource(resize_size=resize_size,
                                  line_thickness=line_thickness,
                                  gaussian_kernel=gaussian_kernel,
                                  gaussian_sigma=gaussian_sigma,
                                  **common_kwargs)

    # Keyword flags so they cannot bind to use_thread or other parameters.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
def create_data_iterator(batch_size,
                         data_list,
                         load_shape,
                         crop_shape,
                         comm=None,
                         shuffle=True,
                         rng=None,
                         with_memory_cache=False,
                         with_parallel=False,
                         with_file_cache=False,
                         flip=True):
    """Create an ADE20K DataIterator, sliced per process when ``comm`` is given.

    Args:
        batch_size: Mini-batch size.
        data_list, load_shape, crop_shape, flip: Forwarded to ``Ade20kIterator``.
        comm: Communicator passed to ``_get_sliced_data_source`` (may be None).
        shuffle: Whether samples are shuffled.
        rng: Random state for the data source and ``data_iterator``.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.
        with_parallel: Accepted for backward compatibility; not forwarded.

    Returns:
        The iterator created by ``data_iterator``.
    """
    ds = Ade20kIterator(data_list,
                        load_shape,
                        crop_shape,
                        shuffle=shuffle,
                        rng=rng,
                        flip=flip)

    # ds.slice turns withMemoryCache flag on forcibly.
    # For data augmentation, this is not desirable and ds.slice is not used here.
    ds = _get_sliced_data_source(ds, comm, shuffle)

    # Flags were previously passed positionally starting at the ``rng`` slot,
    # so ``with_memory_cache`` was consumed as ``rng``. Pass them by keyword.
    return data_iterator(ds,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #4
0
def data_iterator_mnist(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`MnistDataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`MnistDataSource` is able to store all data into memory.

    Args:
        batch_size: Mini-batch size.
        train: Select the training (True) or test (False) split.
        rng: Random state for the data source and the iterator.
        shuffle: Whether the data source shuffles samples.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.

    For example,

    .. code-block:: python

        with data_iterator_mnist(batch_size, train=True) as di:
            for data in di:
                SOME CODE TO USE data.

    '''
    # Keyword flags: positional booleans could bind to use_thread.
    return data_iterator(MnistDataSource(train=train, shuffle=shuffle, rng=rng),
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #5
0
def data_iterator_stl10(batch_size,
                        train=True,
                        rng=None,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False,
                        output_dir=None):
    '''
    Provide DataIterator with :py:class:`STL10DataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`STL10DataSource` is able to store all data into memory.

    This function is a generator. It opens the data source and the iterator
    as context managers, yields the iterator once, and leaves both ``with``
    scopes when the generator is resumed or closed.
    NOTE(review): presumably used through ``contextlib.contextmanager`` —
    confirm at the definition site.

    Args:
        batch_size: Mini-batch size.
        train: Select the training (True) or test (False) split.
        rng: Random state for the data source and the iterator.
        shuffle: Whether the data source shuffles samples.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.
        output_dir: Forwarded to ``STL10DataSource``.
    '''
    with STL10DataSource(train=train, shuffle=shuffle, rng=rng, output_dir=output_dir) as ds, \
        data_iterator(ds,
                      batch_size,
                      rng=rng,
                      with_memory_cache=with_memory_cache,
                      with_file_cache=with_file_cache) as di:
        yield di
Exemple #6
0
def data_iterator_modelnet40_normal_resampled(
    data_dir: str,
    batch_size: int,
    train: bool,
    shuffle: bool,
    num_points: int,
    normalize: bool,
    stop_exhausted: bool = True,
    with_memory_cache: bool = True,
    with_file_cache: bool = False,
    rng: Optional[int] = None,
) -> DataIterator:
    """Build a DataIterator over ``ModelNet40NormalResampledDataset``.

    The dataset receives ``data_dir, batch_size, train, shuffle, num_points,
    normalize`` positionally. The caching, ``rng`` and ``stop_exhausted``
    options are forwarded to ``data_iterator`` by keyword.
    """
    iterator_kwargs = {
        "rng": rng,
        "with_memory_cache": with_memory_cache,
        "with_file_cache": with_file_cache,
        "stop_exhausted": stop_exhausted,
    }
    source = ModelNet40NormalResampledDataset(
        data_dir, batch_size, train, shuffle, num_points, normalize)
    return data_iterator(source, batch_size, **iterator_kwargs)
Exemple #7
0
def SimpleDataIterator(batch_size,
                       root_dir,
                       image_size,
                       comm=None,
                       shuffle=True,
                       rng=None,
                       on_memory=True,
                       fix_aspect_ratio=True):
    """Create a DataIterator over the images found directly in ``root_dir``.

    Only files whose extension is in ``SUPPORT_IMG_EXTS`` are used. The data
    source is sliced per process through ``_get_sliced_data_source``.

    Args:
        batch_size: Mini-batch size.
        root_dir: Directory scanned (non-recursively) for images.
        image_size: Forwarded to ``SimpleDatasource`` as ``img_size``.
        comm: Communicator for slicing (may be None).
        shuffle: Forwarded to ``_get_sliced_data_source``.
        rng: Forwarded to ``SimpleDatasource``.
        on_memory, fix_aspect_ratio: Forwarded to ``SimpleDatasource``.

    Returns:
        The iterator created by ``data_iterator``.

    Raises:
        ValueError: if no supported image is found in ``root_dir``.
    """
    # Sort the file list: os.listdir order is arbitrary and may differ across
    # hosts/processes, which would make shuffling and per-rank slicing
    # non-reproducible.
    paths = sorted(
        os.path.join(root_dir, name) for name in os.listdir(root_dir)
        if os.path.splitext(name)[-1] in SUPPORT_IMG_EXTS)

    if not paths:
        raise ValueError(
            f"[SimpleDataIterator] '{root_dir}' is not found. "
            "Please make sure that you specify the correct directory path.")

    ds = SimpleDatasource(img_paths=paths,
                          img_size=image_size,
                          rng=rng,
                          on_memory=on_memory,
                          fix_aspect_ratio=fix_aspect_ratio)

    logger.info(f"Initialized data iterator. {ds.size} images are found.")

    ds = _get_sliced_data_source(ds, comm, shuffle)

    return data_iterator(ds,
                         batch_size,
                         with_memory_cache=False,
                         use_thread=True,
                         with_file_cache=False)
Exemple #8
0
def data_iterator_yolo(root,
                       args,
                       batch_size,
                       shuffle=True,
                       train=False,
                       image_sizes=None,
                       image_size_change_freq=640,
                       on_memory_data=None,
                       use_cv2=True,
                       shape=None):
    """Create a DataIterator backed by ``YoloDataSource``.

    All options except ``batch_size`` are forwarded to ``YoloDataSource``.
    ``image_size_change_freq`` must be a multiple of ``batch_size``.

    Raises:
        AssertionError: if ``image_size_change_freq % batch_size != 0``.
    """
    assert image_size_change_freq % batch_size == 0, 'image_size_change_freq should be divisible by batch_size'

    yolo_source = YoloDataSource(root,
                                 args,
                                 shuffle=shuffle,
                                 train=train,
                                 image_sizes=image_sizes,
                                 image_size_change_freq=image_size_change_freq,
                                 on_memory_data=on_memory_data,
                                 use_cv2=use_cv2,
                                 shape=shape)
    return data_iterator(yolo_source, batch_size=batch_size)
Exemple #9
0
def data_iterator_cifar100(batch_size,
                           train=True,
                           rng=None,
                           shuffle=True,
                           with_memory_cache=False,
                           with_parallel=False,
                           with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`Cifar100DataSource`
    with_memory_cache, with_parallel and with_file_cache option's default value is all False,
    because :py:class:`Cifar100DataSource` is able to store all data into memory.

    ``with_parallel`` is accepted for backward compatibility and is not
    forwarded to ``data_iterator``.

    For example,

    .. code-block:: python

        with data_iterator_cifar100(batch_size, train=True) as di:
            for data in di:
                SOME CODE TO USE data.

    '''
    # Flags were previously positional starting at the ``rng`` slot, so
    # with_memory_cache was consumed as rng. Pass them by keyword.
    with Cifar100DataSource(train=train, shuffle=shuffle, rng=rng) as ds, \
        data_iterator(ds,
                      batch_size,
                      rng=rng,
                      with_memory_cache=with_memory_cache,
                      with_file_cache=with_file_cache) as di:
        yield di
Exemple #10
0
def data_iterator_cifar10(batch_size,
                          train=True,
                          rng=None,
                          shuffle=True,
                          with_memory_cache=False,
                          with_parallel=False,
                          with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`Cifar10DataSource`
    with_memory_cache, with_parallel and with_file_cache option's default value is all False,
    because :py:class:`Cifar10DataSource` is able to store all data into memory.

    ``with_parallel`` is accepted for backward compatibility and is not
    forwarded to ``data_iterator``.

    For example,

    .. code-block:: python

        with data_iterator_cifar10(batch_size, train=True) as di:
            for data in di:
                SOME CODE TO USE data.

    '''
    # Flags were previously positional starting at the ``rng`` slot, so
    # with_memory_cache was consumed as rng. Pass them by keyword.
    with Cifar10DataSource(train=train, shuffle=shuffle, rng=rng) as ds, \
        data_iterator(ds,
                      batch_size,
                      rng=rng,
                      with_memory_cache=with_memory_cache,
                      with_file_cache=with_file_cache) as di:
        yield di
def create_data_iterator(batch_size, data_list, image_shape, shuffle=True, rng=None,
                         with_memory_cache=False, with_parallel=False, with_file_cache=False, flip=True):
    """Create a CityScapes DataIterator.

    ``data_list``, ``image_shape``, ``shuffle``, ``rng`` and ``flip`` are
    forwarded to ``CityScapesIterator``. The cache flags are forwarded to
    ``data_iterator``. ``with_parallel`` is accepted for backward
    compatibility and is not forwarded.
    """
    source = CityScapesIterator(data_list, image_shape, shuffle=shuffle, rng=rng, flip=flip)
    # Keyword flags: positionally, with_memory_cache used to land in ``rng``.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #12
0
def run(args):
    """Runs the algorithm.

    Trains Tacotron2 on the LJSpeech metadata files using the settings in
    the module-level ``hp``: creates the output directory, sets up the
    nnabla context and communicator, builds train/valid loaders, the
    annealing learning-rate schedule and the Adam optimizer, then runs
    ``Tacotron2Trainer``.
    """
    Path(hp.output_path).mkdir(parents=True, exist_ok=True)

    # nnabla context and (possibly distributed) communicator.
    context = get_extension_context(args.context, device_id='0')
    nn.set_default_context(context)
    hp.comm = CommunicatorWrapper(context)
    hp.event = StreamEventHandler(int(hp.comm.ctx.device_id))

    is_distributed_master = hp.comm.n_procs > 1 and hp.comm.rank == 0
    if is_distributed_master:
        logger.info(f'Distributed training with {hp.comm.n_procs} processes.')

    rng = np.random.RandomState(hp.seed)

    def make_loader(metadata_file, shuffle):
        # Both loaders share the same RandomState instance.
        source = LJSpeechDataSource(metadata_file, hp, shuffle=shuffle, rng=rng)
        return data_iterator(source,
                             batch_size=hp.batch_size,
                             with_memory_cache=False)

    train_loader = make_loader('metadata_train.csv', True)
    valid_loader = make_loader('metadata_valid.csv', False)
    dataloader = {'train': train_loader, 'valid': valid_loader}

    model = Tacotron2(hp)

    # hp.anneal_steps are given in epochs; convert them to iterations.
    iters_per_epoch = train_loader.size // hp.batch_size
    anneal_steps = [epochs * iters_per_epoch for epochs in hp.anneal_steps]
    lr_scheduler = AnnealingScheduler(hp.alpha,
                                      warmup=hp.warmup,
                                      anneal_steps=anneal_steps,
                                      anneal_factor=hp.anneal_factor)
    optimizer = Optimizer(weight_decay=hp.weight_decay,
                          max_norm=hp.max_norm,
                          lr_scheduler=lr_scheduler,
                          name='Adam',
                          alpha=hp.alpha)

    Tacotron2Trainer(model, dataloader, optimizer, hp).run()
Exemple #13
0
def jsi_iterator(batch_size,
                 conf,
                 train,
                 with_memory_cache=False,
                 with_file_cache=False,
                 rng=None):
    """Create a DataIterator over ``JSIData`` (always shuffled).

    Args:
        batch_size: Mini-batch size.
        conf: Configuration forwarded to ``JSIData``.
        train: Split selector forwarded to ``JSIData``.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.
        rng: Random state for the data source and the iterator.
    """
    # Forward the caller's rng. It was previously hard-coded to None, so a
    # seeded rng never reached the data source.
    source = JSIData(conf, train, shuffle=True, rng=rng)
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #14
0
def data_iterator_librispeech(batch_size,
                              data_dir,
                              shuffle=True,
                              rng=None,
                              with_memory_cache=False,
                              with_file_cache=False):
    """Create a DataIterator over ``LibriSpeechDataSource(data_dir)``.

    ``shuffle`` and ``rng`` go to the data source. ``rng`` and the cache
    flags go to ``data_iterator`` (by keyword, so they cannot bind to
    ``use_thread``).
    """
    source = LibriSpeechDataSource(data_dir, shuffle=shuffle, rng=rng)
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #15
0
def load_data(bs_train, bs_valid):
    """Load the CIFAR-10 numpy splits from ``args.input`` and wrap them.

    Reads ``x_train.npy``/``y_shuffle_train.npy`` and ``x_val.npy``/
    ``y_val.npy`` from the directory ``args.input``. ``args`` is a
    module-level name, not a parameter.

    Args:
        bs_train: Batch size of the training iterator.
        bs_valid: Batch size of the validation iterator.

    Returns:
        ``(train_loader, val_loader, train_samples, val_samples)``, where the
        sample counts are the number of labels in each data source.
    """
    def _load(file_name):
        return np.load(os.path.join(args.input, file_name))

    x_train = _load("x_train.npy")
    y_train = _load("y_shuffle_train.npy")
    x_val = _load("x_val.npy")
    y_val = _load("y_val.npy")

    # The training source is built with shuffle=False; the validation source
    # uses Cifar10NumpySource's default.
    data_source_train = Cifar10NumpySource(x_train, y_train, shuffle=False)
    data_source_val = Cifar10NumpySource(x_val, y_val)

    train_samples = len(data_source_train.labels)
    val_samples = len(data_source_val.labels)

    # Keyword flags instead of the positional (None, False, False), which
    # depended on data_iterator's exact parameter order.
    iterator_kwargs = dict(rng=None, with_memory_cache=False, with_file_cache=False)
    train_loader = data_iterator(data_source_train, bs_train, **iterator_kwargs)
    val_loader = data_iterator(data_source_val, bs_valid, **iterator_kwargs)

    return train_loader, val_loader, train_samples, val_samples
Exemple #16
0
def get_photo_tourism_dataiterator(config, split, comm):
    """Create the Phototourism data pipeline for ``split``.

    Args:
        config: Config providing ``data.root``, ``data.downscale``,
            ``data.use_cache`` and, for training, ``train.ray_batch_size``.
        split: One of ``'train'``, ``'val'`` or ``'test'``.
        comm: Communicator with ``n_procs``/``rank`` used to slice the
            training iterator, or ``None``.

    Returns:
        For ``'train'``, a DataIterator (sliced per process when ``comm`` is
        given). For ``'val'``, a DataIterator with batch size 1. For
        ``'test'``, the raw data source.

    Raises:
        ValueError: if ``split`` is not one of the supported values.
    """
    # Validate first. An unknown split previously fell through every branch
    # and crashed with UnboundLocalError after loading the whole dataset.
    if split not in ('train', 'val', 'test'):
        raise ValueError(
            f"Unknown split {split!r}; expected 'train', 'val' or 'test'.")

    print(
        f'Loading {split} images downscaled by a factor of {config.data.downscale}...')
    data_source = PhototourismDataSource(config.data.root, img_downscale=int(config.data.downscale),
                                         use_cache=config.data.use_cache, split=split)
    if split == 'test':
        return data_source
    if split == 'val':
        return data_iterator(data_source, batch_size=1)

    di = data_iterator(data_source, batch_size=config.train.ray_batch_size)
    if comm is not None:
        di = di.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)
    return di
def simple_data_iterator(_data,
                         batch_size,
                         train=True,
                         rng=None,
                         shuffle=True,
                         with_memory_cache=False,
                         with_parallel=False,
                         with_file_cache=False):
    """Create a DataIterator over ``SimpleDataSource(_data)``.

    ``train``, ``shuffle`` and ``rng`` go to the data source. ``rng`` and the
    cache flags go to ``data_iterator``. ``with_parallel`` is accepted for
    backward compatibility and is not forwarded.
    """
    source = SimpleDataSource(_data, train=train, shuffle=shuffle, rng=rng)
    # Keyword flags: positionally, with_memory_cache used to land in ``rng``.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #18
0
def get_data_iterator_ffhq(data_config, batch_size, img_size, comm):
    """Build an FFHQ DataIterator, sliced per process when distributed.

    Slicing happens only when ``comm`` is given and has more than one process.
    """
    iterator = data_iterator(FFHQData(data_config, img_size), batch_size=batch_size)

    is_distributed = comm is not None and comm.n_procs > 1
    if not is_distributed:
        return iterator
    return iterator.slice(rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)
Exemple #19
0
def data_iterator_dtumvs(data_source,
                         batch_size,
                         rng=None,
                         with_memory_cache=False,
                         with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`DTUMVSDataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`DTUMVSDataSource` is able to store all data into memory.

    ``rng`` and the cache flags are passed by keyword so they cannot bind to
    other ``data_iterator`` parameters such as ``use_thread``.
    '''
    return data_iterator(data_source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
def create_data_iterator(batch_size, data_list, image_shape, comm=None, shuffle=True, rng=None,
                         with_memory_cache=False, with_parallel=False, with_file_cache=False, flip=True):
    """Create a CityScapes DataIterator, sliced per process via ``comm``.

    ``data_list``, ``image_shape``, ``shuffle``, ``rng`` and ``flip`` go to
    ``CityScapesIterator``. The source is sliced by
    ``_get_sliced_data_source``. The cache flags go to ``data_iterator``.
    ``with_parallel`` is accepted for backward compatibility and is not
    forwarded.
    """
    ds = CityScapesIterator(data_list, image_shape,
                            shuffle=shuffle, rng=rng, flip=flip)

    ds = _get_sliced_data_source(ds, comm, shuffle=shuffle)

    # Keyword flags: positionally, with_memory_cache used to land in ``rng``.
    return data_iterator(ds,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #21
0
def facade_data_iterator(
        images_root_path,
        batch_size,
        random_crop=True,
        shuffle=True,
        rng=None,
        with_memory_cache=True,
        with_parallel=False,
        with_file_cache=False):
    """Create a DataIterator over ``FacadeDataSource(images_root_path)``.

    ``random_crop``, ``shuffle`` and ``rng`` go to the data source. ``rng``
    and the cache flags go to ``data_iterator``. ``with_parallel`` is
    accepted for backward compatibility and is not forwarded.
    """
    source = FacadeDataSource(images_root_path, random_crop=random_crop, shuffle=shuffle, rng=rng)
    # Previously with_memory_cache (default True) was passed positionally into
    # the ``rng`` slot and with_file_cache into with_memory_cache.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
def data_iterator_cifar100(batch_size,
                           train=True,
                           rng=None,
                           shuffle=True,
                           with_memory_cache=False,
                           with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`Cifar100DataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`Cifar100DataSource` is able to store all data into memory.

    Returns:
        ``(data_source, iterator)``: the ``Cifar100DataSource`` and the
        DataIterator built on top of it.
    '''
    data_source = Cifar100DataSource(train=train, shuffle=shuffle, rng=rng)
    # Keyword flags so they cannot bind to use_thread.
    iterator = data_iterator(data_source,
                             batch_size,
                             rng=rng,
                             with_memory_cache=with_memory_cache,
                             with_file_cache=with_file_cache)
    return data_source, iterator
Exemple #23
0
def Cifar10DataIterator(batch_size,
                        image_size=(32, 32),
                        comm=None,
                        shuffle=True,
                        rng=None,
                        train=True):
    """Create a threaded, uncached CIFAR-10 DataIterator sliced via ``comm``.

    ``image_size`` is accepted but not used by this function.
    """
    base_source = Cifar10DataSource(train=train, shuffle=shuffle, rng=rng)
    sliced_source = _get_sliced_data_source(base_source, comm, shuffle)
    return data_iterator(sliced_source,
                         batch_size,
                         use_thread=True,
                         with_memory_cache=False,
                         with_file_cache=False)
def get_data_iterator_mix(data_root,
                          comm,
                          batch_size,
                          image_size,
                          img_exts=None):
    """Build a DataIterator over ``MixingFacesData``, sliced when distributed.

    Args:
        data_root: Dataset root forwarded to ``MixingFacesData``.
        comm: Communicator; the iterator is sliced only when it is not None
            and has more than one process.
        batch_size: Mini-batch size.
        image_size: Forwarded to ``MixingFacesData``.
        img_exts: File suffixes forwarded to ``MixingFacesData``. Defaults to
            ``['content.png', 'style.png', 'mix.png']``.
    """
    # A fresh list per call avoids sharing a mutable default across calls.
    if img_exts is None:
        img_exts = ['content.png', 'style.png', 'mix.png']

    data_source = MixingFacesData(data_root, image_size, img_exts)
    data_iterator_ = data_iterator(data_source, batch_size=batch_size)

    if comm is not None and comm.n_procs > 1:
        data_iterator_ = data_iterator_.slice(rng=None,
                                              num_of_slices=comm.n_procs,
                                              slice_pos=comm.rank)

    return data_iterator_
Exemple #25
0
def get_data_iterator_attribute(data_root,
                                comm,
                                batch_size,
                                image_size,
                                img_exts=None):
    """Build a DataIterator over ``AttributeFacesData``, sliced when distributed.

    Args:
        data_root: Dataset root forwarded to ``AttributeFacesData``.
        comm: Communicator; the iterator is sliced only when it is not None
            and has more than one process.
        batch_size: Mini-batch size.
        image_size: Forwarded to ``AttributeFacesData``.
        img_exts: File suffixes forwarded to ``AttributeFacesData``. Defaults
            to ``['o.png', 'y.png']``.
    """
    # A fresh list per call avoids sharing a mutable default across calls.
    if img_exts is None:
        img_exts = ['o.png', 'y.png']

    data_source = AttributeFacesData(data_root, image_size, img_exts)
    data_iterator_ = data_iterator(data_source, batch_size=batch_size)

    if comm is not None and comm.n_procs > 1:
        data_iterator_ = data_iterator_.slice(rng=None,
                                              num_of_slices=comm.n_procs,
                                              slice_pos=comm.rank)

    return data_iterator_
def data_iterator_timeseries(dataset_path,
                             batch_size,
                             x_input_length=32,
                             x_output_length=16,
                             x_split_step=32,
                             rng=None,
                             shuffle=True,
                             with_memory_cache=False,
                             with_parallel=False,
                             with_file_cache=False):
    """Create a DataIterator over ``TimeseriesDataSource``.

    The window options, ``shuffle`` and ``rng`` go to the data source.
    ``rng`` and the cache flags go to ``data_iterator``. ``with_parallel`` is
    accepted for backward compatibility and is not forwarded.
    """
    source = TimeseriesDataSource(dataset_path=dataset_path,
                                  x_input_length=x_input_length,
                                  x_output_length=x_output_length,
                                  x_split_step=x_split_step,
                                  shuffle=shuffle,
                                  rng=rng)
    # Keyword flags: positionally, with_memory_cache used to land in ``rng``.
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
def data_iterator_caltech101(batch_size,
                             train=True,
                             rng=None,
                             shuffle=True,
                             width=128,
                             height=128,
                             padding=True,
                             with_memory_cache=False,
                             with_file_cache=False):
    '''
    Provide DataIterator with :py:class:`Caltech101DataSource`
    with_memory_cache and with_file_cache option's default value is all False,
    because :py:class:`Caltech101DataSource` is able to store all data into memory.

    ``width``, ``height``, ``padding``, ``train``, ``shuffle`` and ``rng``
    are forwarded to the data source. ``rng`` and the cache flags are passed
    to ``data_iterator`` by keyword.
    '''
    source = Caltech101DataSource(width=width, height=height, padding=padding,
                                  train=train, shuffle=shuffle, rng=rng)
    return data_iterator(source,
                         batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #28
0
def run(args):
    """Train NVCNet with its discriminator using settings from ``hp``.

    Creates the output directory, sets up the nnabla context and
    communicator, counts speakers from ``hp.speaker_dir``, builds the
    training loader (no validation loader) and two Adam optimizers, then
    runs ``Trainer``.
    """
    Path(hp.output_path).mkdir(parents=True, exist_ok=True)

    # nnabla context and (possibly distributed) communicator.
    context = get_extension_context(args.context, device_id='0')
    nn.set_default_context(context)

    hp.comm = CommunicatorWrapper(context)
    hp.event = StreamEventHandler(int(hp.comm.ctx.device_id))

    # NOTE(review): split('\n') also counts a trailing empty string when the
    # file ends with a newline, so n_speakers may be one more than the listed
    # speakers. Confirm before changing, since it presumably sizes the model.
    with open(hp.speaker_dir) as speaker_file:
        speaker_entries = speaker_file.read().split('\n')
    hp.n_speakers = len(speaker_entries)
    logger.info(f'Training data with {hp.n_speakers} speakers.')

    if hp.comm.n_procs > 1 and hp.comm.rank == 0:
        logger.info(f'Distributed training with {hp.comm.n_procs} processes.')

    rng = np.random.RandomState(hp.seed)
    train_source = VCTKDataSource('metadata_train.csv', hp, shuffle=True, rng=rng)
    train_loader = data_iterator(train_source,
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False,
                                 rng=rng)
    dataloader = {'train': train_loader, 'valid': None}

    def adam(learning_rate):
        return Optimizer(weight_decay=hp.weight_decay,
                         name='Adam',
                         alpha=learning_rate,
                         beta1=hp.beta1,
                         beta2=hp.beta2)

    gen = NVCNet(hp)
    gen_optim = adam(hp.g_lr)
    dis = Discriminator(hp)
    dis_optim = adam(hp.d_lr)
    Trainer(gen, gen_optim, dis, dis_optim, dataloader, rng, hp).run()
def frame_data_iterator(root_dir,
                        frame_shape=(256, 256, 3),
                        id_sampling=False,
                        is_train=True,
                        random_seed=0,
                        augmentation_params=None,
                        batch_size=1,
                        shuffle=True,
                        with_memory_cache=False,
                        with_file_cache=False):
    """Create a DataIterator over ``FramesDataSource``.

    Args:
        root_dir, frame_shape, id_sampling, is_train, augmentation_params,
            shuffle: Forwarded to ``FramesDataSource``.
        random_seed: Forwarded to ``FramesDataSource`` as-is. For
            ``data_iterator`` it is wrapped in ``np.random.RandomState``
            (unless it already is one).
        batch_size: Mini-batch size.
        with_memory_cache, with_file_cache: Forwarded to ``data_iterator``.
    """
    import numpy as np

    # data_iterator's rng is handed to the cache data sources, which call
    # RandomState methods on it; a bare int seed would fail there.
    if isinstance(random_seed, np.random.RandomState):
        rng = random_seed
    else:
        rng = np.random.RandomState(random_seed)

    source = FramesDataSource(root_dir=root_dir,
                              frame_shape=frame_shape,
                              id_sampling=id_sampling,
                              is_train=is_train,
                              random_seed=random_seed,
                              augmentation_params=augmentation_params,
                              shuffle=shuffle)
    return data_iterator(source,
                         batch_size=batch_size,
                         rng=rng,
                         with_memory_cache=with_memory_cache,
                         with_file_cache=with_file_cache)
Exemple #30
0
def train():
    """Train OpenUnmix_CrossNet on MUSDB with the mixed TD/FD loss.

    Arguments come from ``get_train_args()``. When the communicator reports
    more than one process, the train/valid iterators are sliced per rank and
    gradients and validation losses are all-reduced. Each epoch runs
    ``max_iter`` training steps, then a chunked validation pass. It then
    updates the learning rate (ReduceLROnPlateau) and steps early stopping.
    Rank 0 records the monitors and saves ``best_xumx.h5`` into
    ``args.output`` whenever the validation loss equals ``es.best``.

    Raises:
        ValueError: if the installed nnabla is older than v1.19.0.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    # NOTE(review): Monitor is created before the rank-0 makedirs below;
    # presumably Monitor creates the directory itself — confirm.
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    # load_datasources may return an updated ``args``, which replaces the old one.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    # In distributed runs each rank iterates over its own slice of the data.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    # (batches per epoch over the whole dataset, divided among the ranks)
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # NOTE(review): weight decay is multiplied by n_procs, presumably to match
    # how gradients are reduced across workers — confirm.
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    # Training input shapes are taken from the first training sample.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    # Validation runs one fixed-length chunk of valid_dur seconds at a time.
    # The mixture has 2 channels and the target 8 (hard-coded here;
    # presumably 4 stereo sources — confirm).
    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    # loss = mcoef * time-domain (SDR) loss + frequency-domain (MSE) loss
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            # In distributed runs gradients are all-reduced during backward.
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        # Each track is cut into consecutive ``dur``-sample chunks. The loss is
        # averaged over chunks, and a trailing chunk shorter than ``dur`` is
        # skipped. NOTE(review): a track shorter than ``dur`` would feed a
        # shorter array into vmixture_audio — confirm this cannot happen.
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop when the next chunk would be partial or empty (the
                # second condition is already implied by the first).
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            # Average the per-track loss across ranks.
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            # best_epoch recorded here is the value before this epoch's
            # possible update below.
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            # es.best presumably holds the best validation loss so far
            # (updated by es.step above).
            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
def cycle_gan_data_iterator(data_source, batch_size):
    """Wrap ``data_source`` in a DataIterator with memory and file caches disabled."""
    cache_options = {"with_memory_cache": False, "with_file_cache": False}
    return data_iterator(data_source, batch_size=batch_size, **cache_options)