def __init__(self, ext_name, device_id):
    """Set up per-key CPU timing storage for the given backend extension.

    Args:
        ext_name (str): backend extension name used to import the extension module.
        device_id: device identifier; stored as a string.
    """
    super(CpuTimerCallback, self).__init__()
    # Resolve the backend extension module once, up front.
    self.ext_module = import_extension_module(ext_name)
    self.device_id = str(device_id)
    # Ordered containers so results are reported in insertion order.
    self.results = OrderedDict()
    self.key_to_times = OrderedDict()
def __init__(self, ext_name, device_id):
    """
    Args:
        ext_name (str): backend extension name (e.g. cpu, cuda, or cudnn)
        device_id (str): device id
    """
    # Pick the timer-callback class matching the requested backend.
    if ext_name == "cpu":
        profiler_cls = import_module(
            "nnabla.utils.inspection").CpuTimerCallback
    elif ext_name in ["cuda", "cudnn"]:
        profiler_cls = import_extension_module(
            "cuda.utils.inspection").CudaEventTimerCallback
    else:
        # Unsupported extension.
        raise NotImplementedError(
            "Profiler for the extension '{}' is not implemented.".format(
                ext_name))
    self.profiler = profiler_cls

    self.ext_name = ext_name
    self.device_id = device_id
    self._scope_name = ""

    # One profiler instance per named scope; start with the summary scope.
    self.profilers = {}
    self.create_new_profiler("summary")
def test_ext_utils_misc(ext_name):
    """Smoke-test the misc utility API of one backend extension module."""
    mod = ext_utils.import_extension_module(ext_name)
    mod.clear_memory_cache()
    # Nothing further to exercise when no device of this kind is present.
    if mod.get_device_count() == 0:
        return
    devices = mod.get_devices()
    print(devices)
    mod.device_synchronize(devices[0])
def profile_command(args):
    """Run the profiler over the network described by ``args.config`` and
    write per-layer timings to ``<outdir>/profile.csv``.

    Returns True on completion.
    """
    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    files = []
    files.append(args.config)

    # Lightweight attribute-bag containers mirroring the loaded NNP info.
    class TrainConfig:
        pass
    config = TrainConfig()
    info = load.load(files)

    config.global_config = info.global_config
    config.training_config = info.training_config

    class OptConfig:
        pass
    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        # Filled in lazily below, inside the ExitStack.
        o.data_iterator = None
        config.optimizers[name] = o

    class MonConfig:
        pass
    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterator = None
        config.monitors[name] = m

    # Backend name is the part before ':' (e.g. "cudnn" from "cudnn:float").
    ext_module = import_extension_module(
        config.global_config.default_context.backend[0].split(':')[0])

    def synchronize():
        # Device sync so measured times include queued device work.
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    # Profile Optimizer
    with ExitStack() as stack:
        # Enter each optimizer's data iterator context; ExitStack closes all.
        for name, o in config.optimizers.items():
            o.data_iterator = stack.enter_context(o.optimizer.data_iterator())
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    return True
def __init__(self, graph, device_id, ext_name, solver=None, n_run=100,
             max_measure_execution_time=1, time_scale="m"):
    """Profiler over a computation graph.

    Args:
        graph: output variable(s) of the computation graph to profile.
        device_id: device id; stored as a string.
        ext_name (str): backend extension name (e.g. cpu, cuda, cudnn).
        solver: if None, training time (forward + backward + update)
            is not calculated.
        n_run (int): number of measurement runs; must be >= 1.
        max_measure_execution_time: measurement time budget.
        time_scale (str): time unit scale (e.g. "m" for milliseconds).

    Raises:
        AssertionError: if ``n_run`` is less than 1.
    """
    # Validate before doing any expensive work (extension import,
    # parameter-scope scan).  The guard rejects n_run <= 0, so the
    # message says "at least 1" (the old text, "bigger than 1",
    # contradicted the condition).
    if n_run < 1:
        raise AssertionError("n_run must be at least 1")

    self.graph = graph
    # if solver is None, training time (forward + backward + update) is not calculated
    self.solver = solver
    self.n_run = n_run
    self.device_id = str(device_id)
    self.ext_name = ext_name
    self.ext_module = import_extension_module(self.ext_name)
    self.max_measure_execution_time = max_measure_execution_time
    self.time_scale = time_scale
    self.result = dict()
    # Reverse map: parameter Variable -> its name in the global scope.
    self.name2val = {v: k for k, v in nn.get_parameters().items()}
def main():
    """Run single-image DeepLabv3+ semantic-segmentation inference.

    Loads the model, preprocesses ``args.test_image_file``, runs a timed
    forward pass, prints the segmented class labels, and visualizes the
    post-processed result.
    """
    args = get_args()

    # Get context
    from nnabla.ext_utils import get_extension_context, import_extension_module
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)
    ext = import_extension_module(args.context)

    # Read label file; `with` guarantees the handle is closed
    # (the original left the file open).
    with open(args.label_file_path, "r") as f:
        labels_dict = f.readlines()

    # Load parameters
    _ = nn.load_parameters(args.model_load_path)

    # Build a Deeplab v3+ network
    x = nn.Variable((1, 3, args.image_height, args.image_width),
                    need_grad=False)
    y = net.deeplabv3plus_model(
        x, args.output_stride, args.num_class, test=True)

    # Preprocess image to the network input size.
    image = imageio.imread(args.test_image_file, as_gray=False, pilmode="RGB")
    orig_h, orig_w, orig_c = image.shape
    old_size = (orig_h, orig_w)
    input_array = image_preprocess.preprocess_image_and_label(
        image, label=None, target_width=args.image_width,
        target_height=args.image_height, train=False)
    print('Input', input_array.shape)
    # HWC -> CHW, then add a leading batch axis of 1.
    input_array = np.transpose(input_array, (2, 0, 1))
    input_array = np.reshape(
        input_array,
        (1, input_array.shape[0], input_array.shape[1], input_array.shape[2]))

    # Compute inference and inference time
    t = time.time()

    x.d = input_array
    y.forward(clear_buffer=True)
    print("done")
    # Synchronize the device so the elapsed time covers the full forward pass.
    available_devices = ext.get_devices()
    ext.device_synchronize(available_devices[0])
    ext.clear_memory_cache()

    elapsed = time.time() - t
    print('Inference time : %s seconds' % (elapsed))

    output = np.argmax(y.d, axis=1)  # (batch,h,w)

    # Apply post processing
    post_processed = post_process(output[0], old_size,
                                  (args.image_height, args.image_width))

    # Get the classes predicted
    predicted_classes = np.unique(post_processed)
    for i in range(predicted_classes.shape[0]):
        print('Classes Segmented: ', labels_dict[predicted_classes[i]])

    # Visualize inference result
    visualize(post_processed)
def train():
    """Train X-UMX source separation on MUSDB, with optional multi-GPU
    data parallelism, ReduceLROnPlateau scheduling and early stopping.
    Saves the best model (by validation loss) on rank 0.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)
    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries(
        'Validation loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "training time per iteration", monitor, interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    # Each process sees only its own slice of the data in multi-GPU runs.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)
        valid_iter = valid_iter.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # Scale weight decay with the number of processes (gradients are averaged).
    weight_decay = args.weight_decay * comm.n_procs
    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(
        train_source.sample_rate, args.nfft, args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    # NOTE(review): validation mixture is 2-channel and target 8-channel
    # (4 stereo stems) — inferred from the literal shapes; confirm upstream.
    mixture_audio = nn.Variable(
        [args.batch_size] + list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable(
        [args.batch_size] + list(train_source._get_data(0)[1].shape))
    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable(
        [1] + [8, valid_source.sample_rate * args.valid_dur])

    # create training graph
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    # Multi-domain loss: mcoef * time-domain + frequency-domain.
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Overlap gradient all-reduce with backward computation.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            # Slide a fixed-duration window over the full track and
            # average the per-window losses.
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop when the next window would be short, or the track
                # length is an exact multiple of the window size.
                if x[Ellipsis, sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        # All ranks break together (es.step sees the same reduced loss).
        if stop:
            print("Apply Early Stopping")
            break
def test_lms(type_config, device_id, batch_size, num_dilations, learning_rate,
             max_iter, gpu_memory_size, max_prefetch_bytes, cast_prefetch,
             memory):
    """Verify that training under the SwapInOutScheduler (large model
    support) produces (near-)identical parameters to plain training,
    starting from the same initial weights and the same input stream.
    """
    import nnabla_ext.cuda.init as cuda_init

    # Use pinned host memory
    cuda_init.prefer_cpu_pinned_array()

    # Change a type of memory allocator for device
    if memory == "virtual":
        cuda_init.prefer_cuda_virtual_array()
        from nnabla_ext.cuda.init import set_cuda_virtual_memory_chunk_size
        # 2 MiB chunks for the virtual-memory allocator.
        set_cuda_virtual_memory_chunk_size(2 << 20)
    elif memory == "cached":
        cuda_init.prefer_cuda_cached_array()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    cpu_ctx = get_extension_context(
        'cpu', device_id='', type_config=type_config)
    gpu_ctx = get_extension_context(
        'cudnn', device_id=device_id, type_config=type_config)
    nn.set_default_context(gpu_ctx)

    # Input data
    x0 = []
    t0 = []

    # delete=False so the .h5 snapshot survives reopening by
    # nn.load_parameters; it is removed explicitly at the end.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.h5',
                                     delete=False) as tf:
        # Init inputs (fixed seed so both runs see identical batches).
        np.random.seed(seed=32)
        for i in range(max_iter):
            x0 += [np.random.randint(0, 256,
                                     size=(batch_size, data_config.duration, 1))]
            t0 += [np.random.randint(0, 256,
                                     size=(batch_size, data_config.duration, 1))]

        initial_param1 = []
        initial_param2 = []

        ###############################
        # Normal
        ###############################
        with nn.parameter_scope('network1'):
            # Create network
            x, t, loss, solver = create_network(batch_size, num_dilations,
                                                learning_rate)

            # Store the initial parameter
            nn.save_parameters(tf.name)

            # Load the initial parameter
            nn.load_parameters(tf.name)

            for k, p in nn.get_parameters(grad_only=False).items():
                initial_param1.append(p.d)

            # Training loop.
            for i in range(max_iter):
                x.d = x0[i]
                t.d = t0[i]
                loss.forward(clear_no_need_grad=True)
                solver.zero_grad()
                loss.backward(clear_buffer=True)
                solver.update()

            # Synchronization
            ext = ext_utils.import_extension_module('cudnn')
            ext.device_synchronize(device_id)

        ###############################
        # Swap in/out scheduler
        ###############################
        with nn.parameter_scope('network2'):
            # Create network
            x, t, loss, solver = create_network(batch_size, num_dilations,
                                                learning_rate)

            # Load the initial parameter (same snapshot as network1).
            nn.load_parameters(tf.name)

            for k, p in nn.get_parameters(grad_only=False).items():
                initial_param2.append(p.d)

            # Create a scheduler
            scheduler = lms.SwapInOutScheduler(cpu_ctx, gpu_ctx,
                                               gpu_memory_size,
                                               max_prefetch_bytes,
                                               cast_prefetch)

            # Training loop (each iteration runs under the scheduler).
            for i in range(max_iter):
                with scheduler:
                    x.d = x0[i]
                    t.d = t0[i]
                    loss.forward(clear_no_need_grad=True)
                    solver.zero_grad()
                    loss.backward(clear_buffer=True)
                    solver.update()

            # Synchronization
            ext = ext_utils.import_extension_module('cudnn')
            ext.device_synchronize(device_id)

        ###############################
        # Test
        ###############################
        for p1, p2 in zip(initial_param1, initial_param2):
            # Check the identity of initial parameters
            assert np.array_equal(p1, p2)

        with nn.parameter_scope('network1'):
            param1 = nn.get_parameters(grad_only=False)
        with nn.parameter_scope('network2'):
            param2 = nn.get_parameters(grad_only=False)

        # Trained parameters must match within a small float tolerance.
        for (k1, p1), (k2, p2) in zip(param1.items(), param2.items()):
            assert_allclose(p1.d, p2.d, atol=2e-3, rtol=2e-3)

        # Remove the file
        tf.close()
        os.remove(tf.name)
        assert not os.path.exists(tf.name)
def profile_command(args):
    """Profile the network described by ``args.config`` and write per-layer
    timings to ``<outdir>/profile.csv``, reporting status via ``callback``.

    Returns True on completion.
    """
    callback.update_status(args)

    configure_progress(os.path.join(args.outdir, 'progress.txt'))

    # Lightweight attribute-bag containers mirroring the loaded NNP info.
    class TrainConfig:
        pass
    config = TrainConfig()
    info = load.load(args.config)

    config.global_config = info.global_config
    config.training_config = info.training_config

    class OptConfig:
        pass
    config.optimizers = OrderedDict()
    for name, opt in info.optimizers.items():
        o = OptConfig()
        o.optimizer = opt
        # Filled in below, inside the ExitStack.
        o.data_iterators = []
        config.optimizers[name] = o

    class MonConfig:
        pass
    config.monitors = OrderedDict()
    for name, mon in info.monitors.items():
        m = MonConfig()
        m.monitor = mon
        m.data_iterators = []
        config.monitors[name] = m

    # Backend name is the part before ':' (e.g. "cudnn" from "cudnn:float").
    ext_module = import_extension_module(
        config.global_config.default_context.backend[0].split(':')[0])

    def synchronize():
        # Device sync so measured times include queued device work.
        return ext_module.synchronize(
            device_id=config.global_config.default_context.device_id)

    result_array = [['time in ms']]

    callback.update_status('processing', True)

    # Profile Optimizer
    with ExitStack() as stack:
        # Create data_iterator instance only once for each dataset in optimizers
        optimizer_data_iterators = {}
        for name, o in config.optimizers.items():
            for di in o.optimizer.data_iterators.values():
                if di not in optimizer_data_iterators:
                    di_instance = stack.enter_context(di())
                    optimizer_data_iterators[di] = di_instance
                else:
                    # Reuse the already-opened iterator for a shared dataset.
                    di_instance = optimizer_data_iterators[di]
                o.data_iterators.append(di_instance)
        result_array = profile_optimizer(config, result_array, synchronize)

    # Write profiling result
    import csv
    with open(args.outdir + os.sep + 'profile.csv', 'w') as f:
        writer = csv.writer(f, lineterminator='\n')
        writer.writerows(result_array)

    logger.log(99, 'Profile Completed.')
    progress(None)
    callback.update_status('finished')
    return True
def test_import_extension_module(ext_name):
    """Smoke test: importing the extension module by name must not raise."""
    module = ext_utils.import_extension_module(ext_name)
scores = museval.eval_mus_track(track, estimates, output_dir=args.out_dir) # clear cache memory ext.clear_memory_cache() return scores if __name__ == '__main__': # Get the arguments parser args = get_inference_args() # Set NNabla context and Dynamic graph execution ctx = get_extension_context(args.context) nn.set_default_context(ctx) nn.set_auto_forward(True) ext = import_extension_module(args.context) mus = musdb.DB(root=args.root, download=args.root is None, subsets='test', is_wav=args.is_wav) if args.cores > 1: pool = multiprocessing.Pool(args.cores) results = museval.EvalStore() scores_list = list( pool.imap_unordered(func=functools.partial(separate_and_evaluate, args, ext), iterable=mus.tracks, chunksize=1)) pool.close()
def train():
    """Train X-UMX or UMX (selected by ``args.umx_train``) on MUSDB18 with
    optional multi-GPU data parallelism, ReduceLROnPlateau scheduling and
    (for UMX) early stopping. Saves the best model on rank 0.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)
    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries(
        'Validation loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "training time per iteration", monitor, interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB18.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(
        train_source,
        args.batch_size,
        RandomState(args.seed),
        with_memory_cache=False,
    )
    valid_iter = data_iterator(
        valid_source,
        1,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    # Each process sees only its own slice of the data in multi-GPU runs.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)
        valid_iter = valid_iter.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    # Change max_iter, learning_rate and weight_decay according no. of gpu devices for multi-gpu training.
    default_batch_size = 16
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    max_bin = utils.bandwidth_to_max_bin(
        train_source.sample_rate, args.nfft, args.bandwidth)

    # Get X-UMX/UMX computation graph and variables as namedtuple
    model = get_model(args, scaler_mean, scaler_std, max_bin=max_bin)

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(
        lr=args.lr, factor=args.lr_decay_gamma, patience=args.lr_decay_patience)
    best_epoch = 0

    # AverageMeter for mean loss calculation over the epoch
    losses = utils.AverageMeter()

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            model.mixture_audio.d, model.target_audio.d = train_iter.next()
            solver.zero_grad()
            model.loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Overlap gradient all-reduce with backward computation.
                all_reduce_callback = comm.get_all_reduce_callback()
                model.loss.backward(
                    clear_buffer=True,
                    communicator_callbacks=all_reduce_callback)
            else:
                model.loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(model.loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION
        losses.reset()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            # Slide a fixed-duration window over the full track and
            # average the per-window losses.
            while 1:
                model.vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                model.vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                model.vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += model.vloss.data
                # Stop when the next window would be short, or the track
                # length is an exact multiple of the window size.
                if x[Ellipsis, sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            losses.update(loss_tmp.data.copy(), 1)
        validation_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                best_epoch = epoch
                # save best model
                if args.umx_train:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_umx.h5'))
                else:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_xumx.h5'))

        if args.umx_train:
            # Early stopping for UMX after `args.patience` (140) number of epochs
            if stop:
                print("Apply Early Stopping")
                break
def train():
    """Train a D3Net music-source-separation model for ``args.target`` on
    MUSDB, with optional multi-GPU data parallelism and step-annealed LR.
    Rank 0 saves per-epoch and final weights.
    """
    # Check NNabla version
    if get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = os.path.join(args.output, args.target)
    monitor = Monitor(monitor_path)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "training time per epoch", monitor, interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, args = load_datasources(parser, args)
    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False)
    # Each process sees only its own slice of the data in multi-GPU runs.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    # Change max_iter, learning_rate and weight_decay according no. of gpu devices for multi-gpu training.
    default_batch_size = 6
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int(train_source._size // (comm.n_procs * args.batch_size))
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor
    print(f"max_iter per GPU-device:{max_iter}")

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = get_statistics(args, train_source)

    # clear cache memory
    ext.clear_memory_cache()

    # Create input variables.
    mixture_audio = nn.Variable(
        [args.batch_size] + list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable(
        [args.batch_size] + list(train_source._get_data(0)[1].shape))

    with open(f"./configs/{args.target}.yaml") as file:
        # Load target specific Hyper parameters
        hparams = yaml.load(file, Loader=yaml.FullLoader)

    # create training graph
    mix_spec = spectogram(*stft(mixture_audio,
                                n_fft=hparams['fft_size'],
                                n_hop=hparams['hop_size'],
                                patch_length=256),
                          mono=(hparams['n_channels'] == 1))
    target_spec = spectogram(*stft(target_audio,
                                   n_fft=hparams['fft_size'],
                                   n_hop=hparams['hop_size'],
                                   patch_length=256),
                             mono=(hparams['n_channels'] == 1))

    with nn.parameter_scope(args.target):
        d3net = D3NetMSS(hparams,
                         comm=comm.comm,
                         input_mean=scaler_mean,
                         input_scale=scaler_std,
                         init_method='xavier')
        pred_spec = d3net(mix_spec)

    # Spectrogram-domain MSE between prediction and target.
    loss = F.mean(F.squared_error(pred_spec, target_spec))
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize LR Scheduler (AnnealingScheduler)
    lr_scheduler = AnnealingScheduler(
        init_lr=args.lr, anneal_steps=[40], anneal_factor=0.1)

    # AverageMeter for mean loss calculation over the epoch
    losses = AverageMeter()

    for epoch in range(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Overlap gradient all-reduce with backward computation.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # clear cache memory
        ext.clear_memory_cache()

        lr = lr_scheduler.get_learning_rate(epoch)
        solver.set_learning_rate(lr)

        if comm.rank == 0:
            monitor_traing_loss.add(epoch, training_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)
            # save intermediate weights
            nn.save_parameters(f"{os.path.join(args.output, args.target)}.h5")

    if comm.rank == 0:
        # save final weights
        nn.save_parameters(
            f"{os.path.join(args.output, args.target)}_final.h5")