# Shared module-level dependencies for the training loops below. The
# project-specific helpers referenced here (Transformer, get_opt, scheduler,
# SampleIterator, SequentialIterator, DistillationSampleIterator, evaluate,
# to_torch, from_torch, count_params, clear_gpu_memory, hebbian_weight_update,
# Namespace, amp) are assumed to be imported from the surrounding project.
from time import time

import numpy as np
import pandas as pd
import torch


def train(c):
    """Compression-aware training driven by a distiller schedule (c.compress)."""
    import distiller

    net = Transformer(c)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)

    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    data_test = SequentialIterator(c, c.eval_batch, split='test')

    # Baseline evaluation before the compression schedule is applied.
    print('Before quantization')
    tbl, sparsity = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
    step_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
    step_result = step_result.append(pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
    step_result['sparsity'] = sparsity
    print(step_result)

    # Build the compression scheduler from the YAML schedule at c.compress.
    compression_scheduler = distiller.config.file_config(net, opt, c.compress)

    print('After initial quantization')
    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    tbl, sparsity = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
    step_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
    step_result = step_result.append(pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
    step_result['sparsity'] = sparsity
    print(step_result)

    # Remember which weight/bias entries are already zero so pruned values can
    # be re-zeroed after every optimizer step.
    npm = []
    for name, param in net.named_parameters():
        if param.dim() in [2, 4] and any(key in name for key in ['weight', 'bias']):
            npm.append((name, param, param.abs() == 0))

    # Lowest validation loss recorded in previous runs, if any.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()

    try:
        steps_per_epoch = c.step_eval
        while step < s.step_max:
            epoch = step // steps_per_epoch
            batch = step % steps_per_epoch
            if batch == 0:
                compression_scheduler.on_epoch_begin(epoch)
            compression_scheduler.on_minibatch_begin(epoch, batch, steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']
            compression_scheduler.before_backward_pass(
                epoch, batch, steps_per_epoch, loss, return_loss_components=False)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            compression_scheduler.before_parameter_optimization(epoch, batch, steps_per_epoch, opt)
            opt.step()
            # Keep pruned weights at exactly zero.
            for name, param, mask in npm:
                param.data[mask] = 0
            compression_scheduler.on_minibatch_end(epoch, batch, steps_per_epoch)
            if (batch + 1) == steps_per_epoch:
                compression_scheduler.on_epoch_end(epoch)

            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            # Only report perplexity once the loss is small enough to be meaningful.
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model,
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']

            s.step = step = step + 1
            if step % c.step_eval == 0:
                tbl, sparsity = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
                step_result = step_result.append(pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                step_result = step_result.append(pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
                step_result['sparsity'] = sparsity
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
    return net, step
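

# A minimal, hypothetical sketch (not taken from the original project) of the
# kind of distiller compression-schedule YAML that c.compress could point to in
# the quantization-aware train() above, following distiller's standard schedule
# schema. The quantizer class and parameter names are distiller's; the
# bit-widths and epoch range are made-up placeholders.
EXAMPLE_COMPRESSION_SCHEDULE_YAML = """
version: 1
quantizers:
  linear_quantizer:
    class: QuantAwareTrainRangeLinearQuantizer
    bits_activations: 8
    bits_weights: 8
policies:
  - quantizer:
      instance_name: linear_quantizer
    starting_epoch: 0
    ending_epoch: 100
    frequency: 1
"""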


def train(c):
    """Baseline training loop with optional knowledge distillation, apex AMP
    mixed precision, and hebbian softmax updates."""
    c.setdefault(hebbian=False)
    # Instantiate the model class named by c.model (e.g. 'Transformer').
    net = eval(c.model)(c)
    emb_params = (count_params(net.embed) + count_params(net.loss.projections)
                  + count_params(net.loss.clusters))
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)

    if c.get('distill'):
        # Distillation batches carry teacher soft labels alongside hard labels.
        data_tr_distill = DistillationSampleIterator(c, c.train_batch)
        iter_tr_distill = iter(data_tr_distill)
    else:
        data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
        iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.hebbian:
        # One counter per vocabulary id, grouped by adaptive-softmax cluster.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lowest validation loss recorded in previous runs, if any.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()

    try:
        while step < s.step_max:
            step_lr(step)
            t_s = time()

            if c.get('distill'):
                hard_labels, soft_labels, soft_probs = next(iter_tr_distill)
                hard_labels = to_torch(hard_labels, c.device).t()
                # Teacher outputs are shifted to align with the next-token targets.
                soft_labels = to_torch(soft_labels, c.device).permute(1, 0, 2)[1:]
                soft_probs = to_torch(soft_probs, c.device).permute(1, 0, 2)[1:]
                inputs, hard_labels = hard_labels[:-1], hard_labels[1:]
                preds = net(inputs=inputs, labels=hard_labels,
                            soft_labels=soft_labels, soft_probs=soft_probs,
                            current_step=step)
            else:
                x = to_torch(next(iter_tr), c.device).t()
                inputs, labels = x[:-1], x[1:]
                preds = net(inputs, labels)

            loss = preds['loss']
            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            if c.opt_level == 'O0':
                loss.backward()
            else:
                # apex AMP loss scaling for mixed-precision training.
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)

            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model,
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.get('use_cache'):
                step_result['theta'] = from_torch(preds['theta'])
                step_result['lambda'] = from_torch(preds['lambda'])

            s.step = step = step + 1
            if step % c.step_eval == 0:
                step_result = step_result.append(pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
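

# Illustration only (the cutoff and vocab values are hypothetical): the hebbian
# counters above hold one slot per vocabulary id, partitioned into the
# adaptive-softmax clusters defined by c.cutoffs. This sketch mirrors the
# zip(...) used above and returns the per-cluster counter lengths.
def _example_cluster_sizes(cutoffs=(2000, 10000), n_vocab=33278):
    cutoffs = list(cutoffs)
    # e.g. cutoffs=[2000, 10000], n_vocab=33278 -> [2000, 8000, 23278]
    return [end - start for start, end in zip([0] + cutoffs, cutoffs + [n_vocab])]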


def train(c, net, compression_scheduler=None):
    """Training loop for an externally built net, with optional distiller
    compression scheduling and sparsity logging to TensorBoard/Python loggers."""
    import distiller
    import distiller.apputils as apputils
    from distiller.data_loggers import TensorBoardLogger, PythonLogger

    msglogger = apputils.config_pylogger('logging.conf', None)
    tflogger = TensorBoardLogger(msglogger.logdir)
    tflogger.log_gradients = True
    pylogger = PythonLogger(msglogger)

    c.setdefault(hebbian=False)
    emb_params = (count_params(net.embed) + count_params(net.loss.projections)
                  + count_params(net.loss.clusters))
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)

    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)

    c.log('Embedding has %s parameters' % emb_params)

    if c.get('steps_per_epoch'):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print('#### steps per epoch %d ####' % steps_per_epoch)

    if c.hebbian:
        # One counter per vocabulary id, grouped by adaptive-softmax cluster.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    # Lowest validation loss recorded in previous runs, if any.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()

    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if batch == 0:
                c.log('====> batch=%d, epoch=%d, step=%d' % (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)
            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)

            step_lr(step)

            x = to_torch(next(iter_tr), c.device).t()

            t_s = time()
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']

            if compression_scheduler:
                compression_scheduler.before_backward_pass(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch,
                    loss=loss, return_loss_components=False)

            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()

            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)

            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model,
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.use_cache:
                step_result['theta'] = preds['theta']
                step_result['lambda'] = preds['lambda'].item()

            if compression_scheduler:
                compression_scheduler.on_minibatch_end(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)
            if (batch + 1) == steps_per_epoch:
                # End-of-epoch hook for the compression schedule.
                if compression_scheduler:
                    compression_scheduler.on_epoch_end(epoch)

            s.step = step = step + 1
            if step % c.step_eval == 0:
                distiller.log_weights_sparsity(net, epoch, loggers=[tflogger, pylogger])
                tbl, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
                c.log('total sparsity: %.3f' % total)
                step_result = step_result.append(pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
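

# Hedged usage sketch (not part of the original source): one way the third
# train() variant above might be driven. The 'schedule.yaml' path and this
# wiring are hypothetical; note that train() rebuilds its own optimizer
# internally via get_opt/init_model, so this only illustrates the call shape.
def example_run(c, net):
    import distiller
    compression_scheduler = distiller.config.file_config(net, get_opt(c, net), 'schedule.yaml')
    train(c, net, compression_scheduler=compression_scheduler)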