def main():
    opt = parse_option()

    torch.cuda.set_device(opt.gpu)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    encoder = SmallAlexNet(feat_dim=opt.feat_dim).to(opt.gpu)
    encoder.eval()

    train_loader, val_loader = get_data_loaders(opt)

    # Probe the encoder once to get the flattened feature size at the chosen layer.
    with torch.no_grad():
        sample, _ = train_loader.dataset[0]
        eval_numel = encoder(sample.unsqueeze(0).to(opt.gpu),
                             layer_index=opt.layer_index).numel()
    print(f'Feature dimension: {eval_numel}')

    encoder.load_state_dict(
        torch.load(opt.encoder_checkpoint, map_location=opt.gpu))
    print(f'Loaded checkpoint from {opt.encoder_checkpoint}')

    classifier = nn.Linear(eval_numel, 10).to(opt.gpu)

    optim = torch.optim.Adam(classifier.parameters(), lr=opt.lr, betas=(0.5, 0.999))
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim, gamma=opt.lr_decay_rate, milestones=opt.lr_decay_epochs)

    loss_meter = AverageMeter('loss')
    it_time_meter = AverageMeter('iter_time')
    for epoch in range(opt.epochs):
        loss_meter.reset()
        it_time_meter.reset()
        t0 = time.time()
        for ii, (images, labels) in enumerate(train_loader):
            optim.zero_grad()
            # The encoder stays frozen; only the linear classifier is trained.
            with torch.no_grad():
                feats = encoder(images.to(opt.gpu),
                                layer_index=opt.layer_index).flatten(1)
            logits = classifier(feats)
            loss = F.cross_entropy(logits, labels.to(opt.gpu))
            loss_meter.update(loss, images.shape[0])
            loss.backward()
            optim.step()
            it_time_meter.update(time.time() - t0)
            if ii % opt.log_interval == 0:
                print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(train_loader)}\t"
                      f"{loss_meter}\t{it_time_meter}")
            t0 = time.time()
        scheduler.step()
        val_acc = validate(opt, encoder, classifier, val_loader)
        print(f"Epoch {epoch}/{opt.epochs}\tval_acc {val_acc*100:.4g}%")
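
# The `validate` helper used above is not defined in this file. A minimal sketch of
# what it is assumed to do (top-1 accuracy of the frozen encoder plus linear
# classifier over the validation loader); the actual project may differ in details.
def validate(opt, encoder, classifier, val_loader):
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            # Same frozen-feature pipeline as in the training loop.
            feats = encoder(images.to(opt.gpu),
                            layer_index=opt.layer_index).flatten(1)
            pred = classifier(feats).argmax(dim=1)
            correct += (pred.cpu() == labels).sum().item()
    return correct / len(val_loader.dataset)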
def main():
    opt = parse_option()

    print(f'Optimize: {opt.align_w:g} * loss_align(alpha={opt.align_alpha:g}) + '
          f'{opt.unif_w:g} * loss_uniform(t={opt.unif_t:g})')

    torch.cuda.set_device(opt.gpus[0])
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

    encoder = nn.DataParallel(
        SmallAlexNet(feat_dim=opt.feat_dim).to(opt.gpus[0]), opt.gpus)

    optim = torch.optim.SGD(encoder.parameters(), lr=opt.lr,
                            momentum=opt.momentum, weight_decay=opt.weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim, gamma=opt.lr_decay_rate, milestones=opt.lr_decay_epochs)

    loader = get_data_loader(opt)

    align_meter = AverageMeter('align_loss')
    unif_meter = AverageMeter('uniform_loss')
    loss_meter = AverageMeter('total_loss')
    it_time_meter = AverageMeter('iter_time')
    for epoch in range(opt.epochs):
        align_meter.reset()
        unif_meter.reset()
        loss_meter.reset()
        it_time_meter.reset()
        t0 = time.time()
        for ii, (im_x, im_y) in enumerate(loader):
            optim.zero_grad()
            # Encode the two augmented views in a single forward pass, then split.
            x, y = encoder(torch.cat([im_x.to(opt.gpus[0]),
                                      im_y.to(opt.gpus[0])])).chunk(2)
            align_loss_val = align_loss(x, y, alpha=opt.align_alpha)
            unif_loss_val = (uniform_loss(x, t=opt.unif_t) +
                             uniform_loss(y, t=opt.unif_t)) / 2
            loss = align_loss_val * opt.align_w + unif_loss_val * opt.unif_w
            align_meter.update(align_loss_val, x.shape[0])
            unif_meter.update(unif_loss_val)
            loss_meter.update(loss, x.shape[0])
            loss.backward()
            optim.step()
            it_time_meter.update(time.time() - t0)
            if ii % opt.log_interval == 0:
                print(f"Epoch {epoch}/{opt.epochs}\tIt {ii}/{len(loader)}\t"
                      f"{align_meter}\t{unif_meter}\t{loss_meter}\t{it_time_meter}")
            t0 = time.time()
        scheduler.step()
    ckpt_file = os.path.join(opt.save_folder, 'encoder.pth')
    torch.save(encoder.module.state_dict(), ckpt_file)
    print(f'Saved to {ckpt_file}')
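
# `align_loss` and `uniform_loss` are imported from elsewhere in the project. They
# are assumed to follow the standard alignment/uniformity definitions of Wang &
# Isola, "Understanding Contrastive Representation Learning through Alignment and
# Uniformity on the Hypersphere"; minimal sketches, assuming `x` and `y` are
# L2-normalized feature batches of matching shape:
def align_loss(x, y, alpha=2):
    # Mean distance between positive pairs, raised to the power alpha.
    return (x - y).norm(p=2, dim=1).pow(alpha).mean()


def uniform_loss(x, t=2):
    # Log of the average pairwise Gaussian potential; minimized when the features
    # are uniformly distributed on the unit hypersphere.
    return torch.pdist(x, p=2).pow(2).mul(-t).exp().mean().log()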
class Evaluator():
    """Runs evaluation over a data loader and reports classification or regression metrics."""

    def __init__(self, data_loader, logger, config, name='Evaluator',
                 metrics='classfication', summary_writer=None):
        self.data_loader = data_loader
        self.logger = logger
        self.name = name
        self.summary_writer = summary_writer
        self.step = 0
        self.config = config
        self.log_frequency = config.log_frequency
        self.loss_meters = AverageMeter()
        self.acc_meters = AverageMeter()
        self.acc5_meters = AverageMeter()
        self.report_metrics = (self.classfication_metrics
                               if metrics == 'classfication'
                               else self.regression_metrics)

    def log(self, epoch, GLOBAL_STEP):
        display = log_display(epoch=epoch,
                              global_step=GLOBAL_STEP,
                              time_elapse=self.time_used,
                              **self.logger_payload)
        self.logger.info(display)

    def eval(self, epoch, GLOBAL_STEP, model, criterion):
        for i, (images, labels) in enumerate(self.data_loader):
            self.eval_batch(x=images, y=labels, model=model, criterion=criterion)
        self.log(epoch, GLOBAL_STEP)

    def eval_batch(self, x, y, model, criterion):
        model.eval()
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        start = time.time()
        with torch.no_grad():
            pred = model(x)
            loss = criterion(pred, y)
        end = time.time()
        self.time_used = end - start
        self.step += 1
        self.report_metrics(pred, y, loss)

    def classfication_metrics(self, x, y, loss):
        acc, acc5 = accuracy(x, y, topk=(1, 5))
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(acc.item(), y.shape[0])
        self.acc5_meters.update(acc5.item(), y.shape[0])
        self.logger_payload = {"acc": acc,
                               "acc_avg": self.acc_meters.avg,
                               "top5_acc": acc5,
                               "top5_acc_avg": self.acc5_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}
        if self.summary_writer is not None:
            self.summary_writer.add_scalar(os.path.join(self.name, 'acc'), acc, self.step)
            self.summary_writer.add_scalar(os.path.join(self.name, 'loss'), loss, self.step)

    def regression_metrics(self, x, y, loss):
        diff = abs((x - y).mean().detach().item())
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(diff, y.shape[0])
        self.logger_payload = {"|diff|": diff,
                               "|diff_avg|": self.acc_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}
        if self.summary_writer is not None:
            self.summary_writer.add_scalar(os.path.join(self.name, 'diff'), diff, self.step)
            self.summary_writer.add_scalar(os.path.join(self.name, 'loss'), loss, self.step)

    def _reset_stats(self):
        self.loss_meters.reset()
        self.acc_meters.reset()
        self.acc5_meters.reset()
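
# `accuracy` is not defined in this file. It is assumed to be the usual top-k
# accuracy helper returning one scalar tensor per requested k; a minimal sketch
# (here returning fractions in [0, 1], which is an assumption about the convention):
def accuracy(output, target, topk=(1,)):
    # output: (batch, num_classes) logits; target: (batch,) integer labels.
    maxk = max(topk)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    return [correct[:k].reshape(-1).float().sum().mul_(1.0 / target.size(0))
            for k in topk]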
class Trainer():
    """Runs training passes over a data loader and reports classification or regression metrics."""

    def __init__(self, data_loader, logger, config, name='Trainer', metrics='classfication'):
        self.data_loader = data_loader
        self.logger = logger
        self.name = name
        self.step = 0
        self.config = config
        self.log_frequency = config.log_frequency
        self.loss_meters = AverageMeter()
        self.acc_meters = AverageMeter()
        self.acc5_meters = AverageMeter()
        self.report_metrics = (self.classfication_metrics
                               if metrics == 'classfication'
                               else self.regression_metrics)

    def train(self, epoch, GLOBAL_STEP, model, optimizer, criterion):
        model.train()
        for images, labels in self.data_loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            self.train_batch(images, labels, model, criterion, optimizer)
            self.log(epoch, GLOBAL_STEP)
            GLOBAL_STEP += 1
        return GLOBAL_STEP

    def train_batch(self, x, y, model, criterion, optimizer):
        start = time.time()
        model.zero_grad()
        optimizer.zero_grad()
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   self.config.grad_bound)
        optimizer.step()
        self.report_metrics(pred, y, loss)
        self.logger_payload['lr'] = optimizer.param_groups[0]['lr']
        self.logger_payload['|gn|'] = grad_norm
        end = time.time()
        self.step += 1
        self.time_used = end - start

    def log(self, epoch, GLOBAL_STEP):
        if GLOBAL_STEP % self.log_frequency == 0:
            display = log_display(epoch=epoch,
                                  global_step=GLOBAL_STEP,
                                  time_elapse=self.time_used,
                                  **self.logger_payload)
            self.logger.info(display)

    def classfication_metrics(self, x, y, loss):
        acc, acc5 = accuracy(x, y, topk=(1, 5))
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(acc.item(), y.shape[0])
        self.acc5_meters.update(acc5.item(), y.shape[0])
        self.logger_payload = {"acc": acc,
                               "acc_avg": self.acc_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}

    def regression_metrics(self, x, y, loss):
        diff = abs((x - y).mean().detach().item())
        self.loss_meters.update(loss.item(), y.shape[0])
        self.acc_meters.update(diff, y.shape[0])
        self.logger_payload = {"|diff|": diff,
                               "|diff_avg|": self.acc_meters.avg,
                               "loss": loss,
                               "loss_avg": self.loss_meters.avg}

    def _reset_stats(self):
        self.loss_meters.reset()
        self.acc_meters.reset()
        self.acc5_meters.reset()
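
# Trainer, Evaluator, and the training scripts above all rely on an `AverageMeter`
# utility that is not shown here (and the NNabla script below calls `get_avg()` on
# it). The projects may each ship their own variant; this is a minimal sketch that
# covers the operations used in this section (reset, update(val, n), avg, get_avg,
# printing):
class AverageMeter(object):
    def __init__(self, name=''):
        self.name = name
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # `val` is the latest per-batch value; `n` weights it (typically batch size).
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_avg(self):
        return self.avg

    def __str__(self):
        # The PyTorch training loops above print the meters directly.
        return f'{self.name} {float(self.val):.4g} ({float(self.avg):.4g})'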
def train():
    # Check the NNabla version.
    if get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update nnabla to v1.19.0 or later; the memory efficiency of the '
            'core engine was improved in v1.19.0.')

    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Set up monitors for logging.
    monitor_path = os.path.join(args.output, args.target)
    monitor = Monitor(monitor_path)
    monitor_training_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed('training time per epoch', monitor, interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize the DataIterator for MUSDB.
    train_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False)

    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Scale max_iter, learning rate, and weight decay according to the number of
    # GPU devices used for multi-GPU training.
    default_batch_size = 6
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int(train_source._size // (comm.n_procs * args.batch_size))
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor
    print(f"max_iter per GPU-device: {max_iter}")

    # Calculate the statistics (mean and variance) of the dataset.
    scaler_mean, scaler_std = get_statistics(args, train_source)

    # Clear cache memory.
    ext.clear_memory_cache()

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    with open(f"./configs/{args.target}.yaml") as file:
        # Load target-specific hyperparameters.
        hparams = yaml.load(file, Loader=yaml.FullLoader)

    # Create the training graph.
    mix_spec = spectogram(*stft(mixture_audio,
                                n_fft=hparams['fft_size'],
                                n_hop=hparams['hop_size'],
                                patch_length=256),
                          mono=(hparams['n_channels'] == 1))
    target_spec = spectogram(*stft(target_audio,
                                   n_fft=hparams['fft_size'],
                                   n_hop=hparams['hop_size'],
                                   patch_length=256),
                             mono=(hparams['n_channels'] == 1))

    with nn.parameter_scope(args.target):
        d3net = D3NetMSS(hparams,
                         comm=comm.comm,
                         input_mean=scaler_mean,
                         input_scale=scaler_std,
                         init_method='xavier')
        pred_spec = d3net(mix_spec)

    loss = F.mean(F.squared_error(pred_spec, target_spec))
    loss.persistent = True

    # Create the solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize the learning-rate scheduler (AnnealingScheduler).
    lr_scheduler = AnnealingScheduler(init_lr=args.lr, anneal_steps=[40], anneal_factor=0.1)

    # AverageMeter for mean loss calculation over each epoch.
    losses = AverageMeter()

    for epoch in range(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # All-reduce gradients across processes during the backward pass.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # Clear cache memory.
        ext.clear_memory_cache()

        lr = lr_scheduler.get_learning_rate(epoch)
        solver.set_learning_rate(lr)

        if comm.rank == 0:
            monitor_training_loss.add(epoch, training_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            # Save intermediate weights.
            nn.save_parameters(f"{os.path.join(args.output, args.target)}.h5")

    if comm.rank == 0:
        # Save final weights.
        nn.save_parameters(f"{os.path.join(args.output, args.target)}_final.h5")
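
# `AnnealingScheduler` is imported from elsewhere. From the way it is constructed
# and queried above, it is assumed to implement simple step annealing: start at
# `init_lr` and multiply by `anneal_factor` at each epoch listed in `anneal_steps`.
# A minimal sketch under that assumption:
class AnnealingScheduler(object):
    def __init__(self, init_lr, anneal_steps=(), anneal_factor=0.1):
        self.init_lr = init_lr
        self.anneal_steps = list(anneal_steps)
        self.anneal_factor = anneal_factor

    def get_learning_rate(self, epoch):
        # Apply one factor of decay for every annealing step already reached.
        lr = self.init_lr
        for step in self.anneal_steps:
            if epoch >= step:
                lr *= self.anneal_factor
        return lr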