Code Example #1
    def __init__(self, ext_name, device_id):
        super(CpuTimerCallback, self).__init__()

        self.results = OrderedDict()
        self.device_id = str(device_id)
        self.ext_module = import_extension_module(ext_name)
        self.key_to_times = OrderedDict()
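(Example #1 is the `__init__` of `CpuTimerCallback` from `nnabla.utils.inspection`; Example #2 below imports the same class.) Every snippet on this page goes through `import_extension_module`, which resolves an extension name to the Python module `nnabla_ext.<ext_name>`. Below is a minimal sketch of that resolution, using only the utility calls that also appear in Examples #3 and #6; the 'cpu' extension ships with nnabla, so this runs without a GPU.

# Minimal sketch: import_extension_module('X') imports nnabla_ext.X.
from nnabla.ext_utils import import_extension_module

ext = import_extension_module('cpu')    # -> nnabla_ext.cpu
ext.clear_memory_cache()                # release cached allocations
devices = ext.get_devices()             # ids of available devices
if devices:
    ext.device_synchronize(devices[0])  # wait for queued work to finish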
Code Example #2
    def __init__(self, ext_name, device_id):
        """
        Args:
             ext_name (str): backend extension name (e.g. cpu, cuda, or cudnn)
             device_id (str): device id
        """
        if ext_name == "cpu":
            self.profiler = import_module(
                "nnabla.utils.inspection").CpuTimerCallback

        elif ext_name in ["cuda", "cudnn"]:
            self.profiler = import_extension_module(
                "cuda.utils.inspection").CudaEventTimerCallback

        else:
            # Unsupported extension.
            raise NotImplementedError(
                "Profiler for the extension '{}' is not implemented.".format(
                    ext_name))

        self.ext_name = ext_name
        self.device_id = device_id
        self._scope_name = ""
        self.profilers = {}
        self.create_new_profiler("summary")
Code Example #3
def test_ext_utils_misc(ext_name):
    ext = ext_utils.import_extension_module(ext_name)
    ext.clear_memory_cache()
    if ext.get_device_count() == 0:
        return
    ds = ext.get_devices()
    print(ds)
    ext.device_synchronize(ds[0])
Code Example #4
def profile_command(args):
    configure_progress(os.path.join(args.outdir, 'progress.txt'))
    files = []
    files.append(args.config)

    class TrainConfig:
        pass

    config = TrainConfig()
    info = load.load(files)

    config.global_config = info.global_config
    config.training_config = info.training_config

    class OptConfig:
        pass

    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterator = None
        config.optimizers[name] = o

    class MonConfig:
        pass

    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterator = None
        config.monitors[name] = m

    ext_module = import_extension_module(
        config.global_config.default_context.backend[0].split(':')[0])

    def synchronize():
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    # Profile Optimizer
    with ExitStack() as stack:
        for name, o in config.optimizers.items():
            o.data_iterator = stack.enter_context(o.optimizer.data_iterator())
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    return True
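The `synchronize` closure above is the heart of this example: on asynchronous backends the host clock would otherwise stop before queued kernels finish, so each measurement is bracketed by a device synchronization. A minimal sketch of the same pattern; the graph output `y` and the device id `'0'` are assumed for illustration.

import time
from nnabla.ext_utils import import_extension_module

ext_module = import_extension_module('cudnn')  # assumes nnabla-ext-cuda

def synchronize():
    ext_module.synchronize(device_id='0')

start = time.time()
y.forward(clear_buffer=True)   # y: an nn.Variable built elsewhere
synchronize()                  # block until queued device work completes
print('forward took %.3f ms' % ((time.time() - start) * 1000))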
Code Example #5
File: profiler.py Project: gishikawa3/my_nnabla
    def __init__(self, graph, device_id, ext_name, solver=None, n_run=100, max_measure_execution_time=1,
                 time_scale="m"):
        self.graph = graph
        # if solver is None, training time (forward + backward + update) is not calculated
        self.solver = solver
        self.n_run = n_run
        self.device_id = str(device_id)
        self.ext_name = ext_name
        self.ext_module = import_extension_module(self.ext_name)
        self.max_measure_execution_time = max_measure_execution_time
        self.time_scale = time_scale
        self.result = dict()
        self.name2val = {v: k for k, v in nn.get_parameters().items()}

        if self.n_run < 1:
            raise AssertionError("n_run must be greater than or equal to 1")
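A hedged usage sketch for this constructor, assuming it belongs to `GraphProfiler` in `nnabla.utils.profiler` (the signature matches): build a small graph, then profile forward, backward, and the solver update over `n_run` repetitions.

import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S
from nnabla.utils.profiler import GraphProfiler

x = nn.Variable((8, 64))
t = nn.Variable((8, 1))
loss = F.mean(F.squared_error(PF.affine(x, 1), t))

solver = S.Sgd()
solver.set_parameters(nn.get_parameters())

# With solver=None, only forward/backward would be measured (see the
# comment in the snippet above).
prof = GraphProfiler(loss, device_id=0, ext_name='cpu',
                     solver=solver, n_run=100)
prof.run()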
Code Example #6
def main():
    args = get_args()
    rng = np.random.RandomState(1223)

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # read label file
    with open(args.label_file_path, "r") as f:
        labels_dict = f.readlines()

    # Load parameters
    _ = nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_height, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(x,
                                args.output_stride,
                                args.num_class,
                                test=True)

    # preprocess image
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    #image = imread(args.test_image_file).astype('float32')
    orig_h, orig_w, orig_c = image.shape
    old_size = (orig_h, orig_w)

    input_array = image_preprocess.preprocess_image_and_label(
        image,
        label=None,
        target_width=args.image_width,
        target_height=args.image_height,
        train=False)
    print('Input', input_array.shape)
    input_array = np.transpose(input_array, (2, 0, 1))
    input_array = np.reshape(
        input_array,
        (1, input_array.shape[0], input_array.shape[1], input_array.shape[2]))

    # Compute inference and inference time
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time : %s seconds' % (elapsed))

    output = np.argmax(y.d, axis=1)  # (batch,h,w)

    # Apply post processing
    post_processed = post_process(output[0], old_size,
                                  (args.image_height, args.image_width))

    # Get the classes predicted
    predicted_classes = np.unique(post_processed)
    for i in range(predicted_classes.shape[0]):
        print('Classes Segmented: ', labels_dict[predicted_classes[i]])

    # Visualize inference result
    visualize(post_processed)
Code Example #7
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable([1] +
                                [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
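The `ext.clear_memory_cache()` calls between the training and validation phases above are worth isolating: nnabla caches device memory it has allocated, and clearing the cache between phases hands those blocks back before the next phase allocates differently shaped buffers. The skeleton, with hypothetical epoch helpers standing in for the loops above:

from nnabla.ext_utils import import_extension_module

ext = import_extension_module('cudnn')  # same module as args.context above

for epoch in range(num_epochs):  # num_epochs: assumed for this sketch
    run_training_epoch()         # hypothetical helper for the loop above
    ext.clear_memory_cache()     # drop cached device allocations
    run_validation_epoch()       # hypothetical helper
    ext.clear_memory_cache()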
Code Example #8
def test_lms(type_config, device_id, batch_size, num_dilations, learning_rate,
             max_iter, gpu_memory_size, max_prefetch_bytes, cast_prefetch,
             memory):

    import nnabla_ext.cuda.init as cuda_init

    # Use pinned host memory
    cuda_init.prefer_cpu_pinned_array()

    # Change a type of memory allocator for device
    if memory == "virtual":
        cuda_init.prefer_cuda_virtual_array()

        from nnabla_ext.cuda.init import set_cuda_virtual_memory_chunk_size
        set_cuda_virtual_memory_chunk_size(2 << 20)
    elif memory == "cached":
        cuda_init.prefer_cuda_cached_array()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    cpu_ctx = get_extension_context('cpu',
                                    device_id='',
                                    type_config=type_config)
    gpu_ctx = get_extension_context('cudnn',
                                    device_id=device_id,
                                    type_config=type_config)
    nn.set_default_context(gpu_ctx)

    # Input data
    x0 = []
    t0 = []

    with tempfile.NamedTemporaryFile(mode='w', suffix='.h5',
                                     delete=False) as tf:
        # Init inputs
        np.random.seed(seed=32)
        for i in range(max_iter):
            x0 += [
                np.random.randint(0,
                                  256,
                                  size=(batch_size, data_config.duration, 1))
            ]
            t0 += [
                np.random.randint(0,
                                  256,
                                  size=(batch_size, data_config.duration, 1))
            ]

        initial_param1 = []
        initial_param2 = []

        ###############################
        # Normal
        ###############################
        with nn.parameter_scope('network1'):
            # Create network
            x, t, loss, solver = create_network(batch_size, num_dilations,
                                                learning_rate)

            # Store the initial parameter
            nn.save_parameters(tf.name)

            # Load the initial parameter
            nn.load_parameters(tf.name)
            for k, p in nn.get_parameters(grad_only=False).items():
                initial_param1.append(p.d)

            # Training loop.
            for i in range(max_iter):
                x.d = x0[i]
                t.d = t0[i]

                loss.forward(clear_no_need_grad=True)

                solver.zero_grad()
                loss.backward(clear_buffer=True)

                solver.update()

            # Synchronization
            ext = ext_utils.import_extension_module('cudnn')
            ext.device_synchronize(device_id)

        ###############################
        # Swap in/out scheduler
        ###############################
        with nn.parameter_scope('network2'):
            # Create network
            x, t, loss, solver = create_network(batch_size, num_dilations,
                                                learning_rate)

            # Load the initial parameter
            nn.load_parameters(tf.name)
            for k, p in nn.get_parameters(grad_only=False).items():
                initial_param2.append(p.d)

            # Create a scheduler
            scheduler = lms.SwapInOutScheduler(cpu_ctx, gpu_ctx,
                                               gpu_memory_size,
                                               max_prefetch_bytes,
                                               cast_prefetch)

            # Training loop.
            for i in range(max_iter):
                with scheduler:
                    x.d = x0[i]
                    t.d = t0[i]

                    loss.forward(clear_no_need_grad=True)

                    solver.zero_grad()
                    loss.backward(clear_buffer=True)

                    solver.update()

            # Synchronization
            ext = ext_utils.import_extension_module('cudnn')
            ext.device_synchronize(device_id)

        ###############################
        # Test
        ###############################
        for p1, p2 in zip(initial_param1, initial_param2):
            # Check the identity of initial parameters
            assert np.array_equal(p1, p2)

        with nn.parameter_scope('network1'):
            param1 = nn.get_parameters(grad_only=False)

        with nn.parameter_scope('network2'):
            param2 = nn.get_parameters(grad_only=False)

        for (k1, p1), (k2, p2) in zip(param1.items(), param2.items()):
            assert_allclose(p1.d, p2.d, atol=2e-3, rtol=2e-3)

        # Remove the file
        tf.close()
        os.remove(tf.name)
        assert not os.path.exists(tf.name)
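Reduced to its essentials, the swap-in/out half of this test works as follows: a `SwapInOutScheduler` is bound to a host context, a device context, and a device memory budget, and each training iteration runs inside `with scheduler:` so intermediate arrays can be swapped out to host memory when the budget is exceeded. A hedged sketch with the same constructor arguments as above, assuming `SwapInOutScheduler` is importable from `nnabla.lms` as the `lms.` prefix suggests; `x`, `t`, `loss`, `solver`, and the data feeder are assumed to exist.

import nnabla as nn
from nnabla.ext_utils import get_extension_context
from nnabla.lms import SwapInOutScheduler

cpu_ctx = get_extension_context('cpu', device_id='')
gpu_ctx = get_extension_context('cudnn', device_id='0')
nn.set_default_context(gpu_ctx)

# Device memory budget in bytes (8 GiB here is an assumed value).
scheduler = SwapInOutScheduler(cpu_ctx, gpu_ctx, 8 << 30)

for i in range(max_iter):        # x, t, loss, solver built elsewhere
    with scheduler:
        x.d, t.d = next_batch()  # hypothetical data feeder
        loss.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss.backward(clear_buffer=True)
        solver.update()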
Code Example #9
File: profile.py Project: alikabeel-sony/nnabla
def profile_command(args):
    callback.update_status(args)

    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    class TrainConfig:
        pass

    config = TrainConfig()
    info = load.load(args.config)

    config.global_config = info.global_config
    config.training_config = info.training_config

    class OptConfig:
        pass

    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        o.data_iterators = []
        config.optimizers[name] = o

    class MonConfig:
        pass

    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []
        config.monitors[name] = m

    ext_module = import_extension_module(
        config.global_config.default_context.backend[0].split(':')[0])

    def synchronize():
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    callback.update_status('processing', True)

    # Profile Optimizer
    with ExitStack() as stack:
        # Create data_iterator instance only once for each dataset in optimizers
        optimizer_data_iterators = {}
        for name, o in config.optimizers.items():
            for di in o.optimizer.data_iterators.values():
                if di not in optimizer_data_iterators:
                    di_instance = stack.enter_context(di())
                    optimizer_data_iterators[di] = di_instance
                else:
                    di_instance = optimizer_data_iterators[di]
                o.data_iterators.append(di_instance)
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    callback.update_status('finished')
    return True
Code Example #10
def test_import_extension_module(ext_name):
    ext = ext_utils.import_extension_module(ext_name)
Code Example #11
File: eval.py Project: sony/ai-research-code
    scores = museval.eval_mus_track(track, estimates, output_dir=args.out_dir)
    # clear cache memory
    ext.clear_memory_cache()
    return scores


if __name__ == '__main__':
    # Get the arguments parser
    args = get_inference_args()

    # Set NNabla context and Dynamic graph execution
    ctx = get_extension_context(args.context)
    nn.set_default_context(ctx)
    nn.set_auto_forward(True)
    ext = import_extension_module(args.context)

    mus = musdb.DB(root=args.root,
                   download=args.root is None,
                   subsets='test',
                   is_wav=args.is_wav)

    if args.cores > 1:
        pool = multiprocessing.Pool(args.cores)
        results = museval.EvalStore()
        scores_list = list(
            pool.imap_unordered(func=functools.partial(separate_and_evaluate,
                                                       args, ext),
                                iterable=mus.tracks,
                                chunksize=1))
        pool.close()
Code Example #12
def train():
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)

    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss',
                                            monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB18.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(
        train_source,
        args.batch_size,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    valid_iter = data_iterator(
        valid_source,
        1,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    # Change max_iter, learning_rate and weight_decay according to the no. of GPU devices for multi-GPU training.
    default_batch_size = 16
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    # Get X-UMX/UMX computation graph and variables as namedtuple
    model = get_model(args, scaler_mean, scaler_std, max_bin=max_bin)

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # AverageMeter for mean loss calculation over the epoch
    losses = utils.AverageMeter()

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            model.mixture_audio.d, model.target_audio.d = train_iter.next()
            solver.zero_grad()
            model.loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                model.loss.backward(clear_buffer=True,
                                    communicator_callbacks=all_reduce_callback)
            else:
                model.loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(model.loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        losses.reset()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while True:
                model.vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                model.vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                model.vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += model.vloss.data
                if x[Ellipsis,
                     sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            losses.update(loss_tmp.data.copy(), 1)
        validation_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_training_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                best_epoch = epoch
                # save best model
                if args.umx_train:
                    nn.save_parameters(os.path.join(args.output,
                                                    'best_umx.h5'))
                else:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_xumx.h5'))

        if args.umx_train:
            # Early stopping for UMX after `args.patience` (140) epochs
            if stop:
                print("Apply Early Stopping")
                break
Code Example #13
File: train.py Project: sony/ai-research-code
def train():
    # Check NNabla version
    if get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency '
            'of the core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = os.path.join(args.output, args.target)
    monitor = Monitor(monitor_path)

    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per epoch",
                                      monitor,
                                      interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Change max_iter, learning_rate and weight_decay according to the no. of GPU devices for multi-GPU training.
    default_batch_size = 6
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size

    max_iter = int(train_source._size // (comm.n_procs * args.batch_size))
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor

    print(f"max_iter per GPU-device:{max_iter}")

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    with open(f"./configs/{args.target}.yaml") as file:
        # Load target specific Hyper parameters
        hparams = yaml.load(file, Loader=yaml.FullLoader)

    # create training graph
    mix_spec = spectogram(*stft(mixture_audio,
                                n_fft=hparams['fft_size'],
                                n_hop=hparams['hop_size'],
                                patch_length=256),
                          mono=(hparams['n_channels'] == 1))
    target_spec = spectogram(*stft(target_audio,
                                   n_fft=hparams['fft_size'],
                                   n_hop=hparams['hop_size'],
                                   patch_length=256),
                             mono=(hparams['n_channels'] == 1))

    with nn.parameter_scope(args.target):
        d3net = D3NetMSS(hparams,
                         comm=comm.comm,
                         input_mean=scaler_mean,
                         input_scale=scaler_std,
                         init_method='xavier')
        pred_spec = d3net(mix_spec)

    loss = F.mean(F.squared_error(pred_spec, target_spec))
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize LR Scheduler (AnnealingScheduler)
    lr_scheduler = AnnealingScheduler(init_lr=args.lr,
                                      anneal_steps=[40],
                                      anneal_factor=0.1)

    # AverageMeter for mean loss calculation over the epoch
    losses = AverageMeter()

    for epoch in range(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.get_learning_rate(epoch)
        solver.set_learning_rate(lr)

        if comm.rank == 0:
            monitor_training_loss.add(epoch, training_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            # save intermediate weights
            nn.save_parameters(f"{os.path.join(args.output, args.target)}.h5")

    if comm.rank == 0:
        # save final weights
        nn.save_parameters(
            f"{os.path.join(args.output, args.target)}_final.h5")