def evaluate_cache_search(config, net):
    """Grid-search the cache hyperparameters (theta, lambda) on the
    validation split, then evaluate the best pair once on the test split.

    Args:
        config: experiment Config; provides eval_batch, init_model, var, log.
        net: language model to evaluate (possibly (Distributed)DataParallel-wrapped).

    Returns:
        The `evaluate(...)` output dict for the test split using the best
        (theta, lambda) found on validation.
    """
    opt = get_opt(config, net)
    net, opt, step = config.init_model(net, opt=opt, step='max', train=True)
    distiller.model_summary(net, "sparsity", 'wikitext-103')

    perplexity = {}  # (theta, lambda) -> validation perplexity

    # Baseline: validation result without the cache.
    data_val = SequentialIterator(config, config.eval_batch, split="valid")
    nocache_ppl = evaluate(config, data_val, net)
    config.log("nocache val ppl: %s" % nocache_ppl)  # typo "nocahce" fixed

    thetas = [2e-2, 1e-2, 9e-3, 8e-3, 7e-3, 6e-3, 5e-3, 4e-3, 3e-3, 2e-3, 1e-3]
    lambdas = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1]
    # Restrict the grid to keep the search affordable.
    thetas = thetas[:5]
    lambdas = lambdas[3:9]

    best_theta = -1
    best_lambda = -1
    best_ppl = float('inf')  # was 1000000; inf is the conventional sentinel

    for theta in thetas:
        for lam in lambdas:
            if (theta, lam) in perplexity:
                continue
            # Reset the cache between configurations; the loss module lives on
            # net.loss, or net.module.loss when the model is wrapped in
            # (Distributed)DataParallel.
            try:
                net.loss.cache_keys = net.loss.cache_values = None
            except AttributeError:  # was a bare except
                net.module.loss.cache_keys = net.module.loss.cache_values = None
            perplexity[theta, lam] = evaluate(
                config.var(use_cache=True, n_cache=2000,
                           cache_theta=theta, cache_lambda=lam),
                data_val, net)['perplexity']
            print("ppl theta=", theta, " lam=", lam,
                  "perplexity=", perplexity[theta, lam])  # typo "perpelxity" fixed
            # BUG FIX: the original ran a full test-set evaluation here on
            # every grid point, with a misspelled kwarg (`cache_thetaa`) and
            # the not-yet-final best parameters; that spurious, expensive
            # evaluation is removed.
            if perplexity[theta, lam] < best_ppl:
                best_theta = theta
                best_lambda = lam
                best_ppl = perplexity[theta, lam]

    # Final evaluation on the test split with the best hyperparameters.
    data_test = SequentialIterator(config, config.eval_batch, split="test")
    print("Final Evaluation")
    distiller.model_summary(net, "sparsity", 'wikitext-103')
    eval_output = evaluate(
        config.var(use_cache=True, n_cache=2000,
                   cache_theta=best_theta,  # typo kwarg "cache_thetaa" fixed
                   cache_lambda=best_lambda),
        data_test, net)
    config.log("VAL RESULT: ppl(%.3lf) theta(%.3lf) lambda(%.3lf)"
               % (best_ppl, best_theta, best_lambda))
    config.log("TEST RESULT: %s" % eval_output)
    return eval_output
def train(c):
    """Train a language model with optional distillation and Hebbian updates.

    Builds the model named by ``c.model``, restores the latest checkpoint,
    then runs the training loop until ``s.step_max`` steps, evaluating on the
    validation split every ``c.step_eval`` steps. Progress is reported through
    the config callbacks (``on_train_start`` / ``on_step_end`` /
    ``on_train_end``).

    Args:
        c: experiment Config providing model/optimizer construction, data
           iterators' parameters, and lifecycle callbacks.
    """
    c.setdefault(hebbian=False)
    # NOTE(review): eval() on a config-supplied string — fine for trusted
    # experiment configs, but never expose c.model to untrusted input.
    net = eval(c.model)(c)
    # Parameter count of the (adaptive) embedding/softmax, logged below.
    emb_params = count_params(net.embed) + count_params(
        net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    if c.get('distill'):
        # Distillation batches carry (hard labels, teacher top-k indices,
        # teacher top-k probabilities).
        data_tr_distill = DistillationSampleIterator(c, c.train_batch)
        iter_tr_distill = iter(data_tr_distill)
    else:
        # In debug mode train on the (smaller) validation split.
        data_tr = SampleIterator(c,
                                 c.train_batch,
                                 split='valid' if c.debug else 'train')
        iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')

    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    c.log('Embedding has %s parameters' % emb_params)

    if c.hebbian:
        # One counter per adaptive-softmax cluster, partitioned by c.cutoffs.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        # NOTE(review): .max() picks the WORST past validation loss, and
        # best_val_loss is never updated inside the loop — so record_step
        # fires for any step better than the historical worst. Confirm
        # whether .min() (true best) was intended.
        best_val_loss = s.results['val_loss'].dropna().max()
    try:
        while step < s.step_max:
            step_lr(step)
            t_s = time()
            if c.get('distill'):
                hard_labels, soft_labels, soft_probs = next(iter_tr_distill)
                # .t() -> (seq, batch); permute/[1:] aligns teacher targets
                # with the shifted-by-one prediction positions.
                hard_labels = to_torch(hard_labels, c.device).t()
                soft_labels = to_torch(soft_labels,
                                       c.device).permute(1, 0, 2)[1:]
                soft_probs = to_torch(soft_probs, c.device).permute(1, 0, 2)[1:]
                inputs, hard_labels = hard_labels[:-1], hard_labels[1:]
                preds = net(inputs=inputs,
                            labels=hard_labels,
                            soft_labels=soft_labels,
                            soft_probs=soft_probs,
                            current_step=step)
            else:
                # Standard next-token setup: inputs are tokens[:-1],
                # labels are tokens[1:].
                x = to_torch(next(iter_tr), c.device).t()
                inputs, labels = x[:-1], x[1:]
                preds = net(inputs, labels)
            loss = preds['loss']
            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            if c.opt_level == 'O0':
                # Pure FP32: plain backward, no loss scaling needed.
                loss.backward()
            else:
                # Mixed precision via NVIDIA apex amp loss scaling.
                with amp.scale_loss(loss, opt) as scaled_loss:
                    scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))
            opt.step()
            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters,
                                      temp_counters)
            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            # Guard against overflow: e**loss explodes for large losses.
            perplexity = np.nan if loss > 5 else np.e**loss
            step_result = pd.Series(
                dict(
                    loss=loss,
                    perplexity=perplexity,
                    time=time_model,
                )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.get('use_cache'):
                step_result['theta'] = from_torch(preds['theta'])
                step_result['lambda'] = from_torch(preds['lambda'])
            s.step = step = step + 1
            if step % c.step_eval == 0:
                # Periodic validation; record_step marks an improvement
                # checkpoint for the on_step_end callback.
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception as e:
        # Top-level boundary: log the traceback instead of crashing the job.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
net, device_ids=[args.local_rank], output_device=args.local_rank, ) logging.info("Number of GPUs: {}, using DistributedDaraParallel.".format( args.num_gpus)) ##################### Loss function and optimizer ############################ criterion_eval = get_criterion(cfg, train=False) criterion_eval.cuda() optimizer = None scheduler = None if not cfg.EVALUATE: criterion = get_criterion(cfg) criterion.cuda() optimizer = get_opt(cfg, net, resume=iteration > 0) scheduler = get_lr_scheduler(cfg, optimizer, last_iter=iteration) ##################### make a checkpoint ############################ best_acc = 0.0 checkpointer = Checkpointer(net, cfg.MODEL.ARCH, best_acc=best_acc, optimizer=optimizer, scheduler=scheduler, save_dir=cfg.OUTPUT_DIR, is_test=cfg.EVALUATE, only_save_last=cfg.ONLY_SAVE_LAST) filepath = cfg.MODEL.MODEL_PATH if not os.path.isfile(filepath):
def train(c):
    """Fine-tune a Transformer under a distiller compression schedule.

    Evaluates the model before and after applying the compression config
    (``c.compress``), then trains while invoking the distiller scheduler's
    epoch/minibatch callbacks around each step, re-applying zero masks to
    pruned weights after every optimizer update.

    Args:
        c: experiment Config (data, optimizer, compression file path,
           lifecycle callbacks).

    Returns:
        (net, step): the trained model and the final step count.
    """
    import distiller
    net = Transformer(c)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c,
                             c.train_batch,
                             split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    data_test = SequentialIterator(c, c.eval_batch, split='test')

    # Snapshot quality/sparsity before the compression schedule is applied.
    print('Before quantization')
    tbl, sparsity = distiller.weights_sparsity_tbl_summary(
        net, return_total_sparsity=True)
    step_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
    step_result = step_result.append(
        pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
    step_result['sparsity'] = sparsity
    print(step_result)

    # Build the compression scheduler from the YAML config; this may already
    # modify the model (e.g. insert quantization wrappers).
    compression_scheduler = distiller.config.file_config(net, opt, c.compress)
    print('After initial quantization')
    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    tbl, sparsity = distiller.weights_sparsity_tbl_summary(
        net, return_total_sparsity=True)
    step_result = pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
    step_result = step_result.append(
        pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
    step_result['sparsity'] = sparsity
    print(step_result)

    # Record (name, param, zero-mask) for all 2-D/4-D weight/bias tensors so
    # pruned entries can be forced back to zero after each optimizer step.
    # NOTE: `type` shadows the builtin here; kept as-is (doc-only edit).
    npm = []
    for name, param in net.named_parameters():
        if param.dim() in [2, 4] and any(type in name
                                         for type in ['weight', 'bias']):
            npm.append((name, param, param.abs() == 0))

    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        # NOTE(review): .max() is the worst past val_loss and it is never
        # updated in the loop; confirm .min() wasn't intended.
        best_val_loss = s.results['val_loss'].dropna().max()
    try:
        # One "epoch" for the distiller scheduler is c.step_eval steps.
        steps_per_epoch = c.step_eval
        while step < s.step_max:
            epoch = step // steps_per_epoch
            batch = step % steps_per_epoch
            if batch == 0:
                compression_scheduler.on_epoch_begin(epoch)
            compression_scheduler.on_minibatch_begin(epoch, batch,
                                                     steps_per_epoch)
            step_lr(step)
            x = to_torch(next(iter_tr), c.device).t()
            t_s = time()
            # Next-token setup: inputs tokens[:-1], labels tokens[1:].
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']
            # NOTE(review): return value (possibly regularized loss) is
            # discarded; confirm the schedule adds no loss terms.
            compression_scheduler.before_backward_pass(epoch, batch,
                                                       steps_per_epoch, loss,
                                                       False)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))
            compression_scheduler.before_parameter_optimization(
                epoch, batch, steps_per_epoch, opt)
            opt.step()
            # Re-apply pruning masks: keep previously-zero weights at zero.
            for name, param, mask in npm:
                param.data[mask] = 0
            compression_scheduler.on_minibatch_end(epoch, batch,
                                                   steps_per_epoch)
            if (batch + 1) == steps_per_epoch:
                compression_scheduler.on_epoch_end(epoch)
            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            # Guard against overflow: e**loss explodes for large losses.
            perplexity = np.nan if loss > 5 else np.e**loss
            step_result = pd.Series(
                dict(
                    loss=loss,
                    perplexity=perplexity,
                    time=time_model,
                )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            s.step = step = step + 1
            if step % c.step_eval == 0:
                # Periodic eval on both splits plus a sparsity summary.
                tbl, sparsity = distiller.weights_sparsity_tbl_summary(
                    net, return_total_sparsity=True)
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_test, net)).add_prefix('test_'))
                step_result['sparsity'] = sparsity
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception as e:
        # Top-level boundary: log the traceback instead of crashing the job.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
    return net, step
def gen_soft_labels(c):
    """Generate and save top-k teacher soft labels for the training split.

    Runs the restored model over the training split under ``torch.no_grad``,
    keeps the top-k probabilities and their vocabulary indices for every
    position, prepends a copy of the first row (aligning the arrays with the
    unshifted token stream, since targets are inputs shifted by one), saves
    both arrays as per-worker ``.npy`` files, then reloads the indices and
    prints the fraction of true tokens covered by the saved top-k set.

    Args:
        c: experiment Config; must provide gen_soft_batch, topk,
           file_out_path and worker.
    """
    c.setdefault(hebbian=False, distributed=False)
    net = get_net(c)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    print('generating soft labels...')
    data_gen_tr = SequentialIteratorGenSoft(c,
                                            c.get('gen_soft_batch'),
                                            split='train')
    clear_gpu_memory()
    net.eval()
    # PERF FIX: the original called np.concatenate inside the loop, which is
    # quadratic in total length. Collect chunks and concatenate once instead.
    soft_index_chunks = []
    soft_value_chunks = []
    with torch.no_grad():
        for batch in tqdm(data_gen_tr):
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            probs, _ = net(inputs, labels)
            values, indices = torch.topk(probs, c.get('topk'), dim=1)
            indices_ = indices.cpu().numpy()
            values_ = values.cpu().numpy()
            # The model can emit more rows than this batch's inputs
            # (presumably carried-over context — TODO confirm); keep only the
            # rows aligned with the current inputs.
            if probs.size(0) != inputs.size(0):
                indices_ = indices_[-inputs.size(0):, :]
                values_ = values_[-inputs.size(0):, :]
            soft_index_chunks.append(indices_)
            soft_value_chunks.append(values_)
    all_soft_indices = np.concatenate(soft_index_chunks, axis=0)
    all_soft_values = np.concatenate(soft_value_chunks, axis=0)
    # Duplicate the first row so row k lines up with token k of the raw
    # (unshifted) stream.
    all_soft_indices = np.concatenate(
        (all_soft_indices[0:1, :], all_soft_indices), axis=0)
    all_soft_values = np.concatenate(
        (all_soft_values[0:1, :], all_soft_values), axis=0)
    np.save(
        c.get('file_out_path') + 'all_soft_indices' + str(c.get('worker')) +
        '.npy', all_soft_indices)
    np.save(
        c.get('file_out_path') + 'all_soft_values' + str(c.get('worker')) +
        '.npy', all_soft_values)
    # Sanity check: reload the saved indices and measure top-k coverage of
    # the true tokens.
    in_indices = np.load(
        c.get('file_out_path') + 'all_soft_indices' + str(c.get('worker')) +
        '.npy')
    cnt = 0.
    for k in range(len(data_gen_tr.tokens)):
        if data_gen_tr.tokens[k] in in_indices[k]:
            cnt += 1
    print(cnt / len(data_gen_tr.tokens))
def train(c, net, compression_scheduler=None):
    """Train ``net`` with optional distiller compression and TB/py logging.

    Wraps the standard training loop with distiller's epoch/minibatch
    scheduler callbacks (when ``compression_scheduler`` is given) and logs
    weight sparsity to TensorBoard and Python loggers at every evaluation.

    Args:
        c: experiment Config (data, optimizer, schedule, callbacks).
        net: model to train; must expose .embed and .loss submodules.
        compression_scheduler: optional distiller CompressionScheduler.
    """
    import distiller.apputils as apputils
    from distiller.data_loggers import TensorBoardLogger, PythonLogger
    msglogger = apputils.config_pylogger('logging.conf', None)
    tflogger = TensorBoardLogger(msglogger.logdir)
    tflogger.log_gradients = True
    pylogger = PythonLogger(msglogger)
    c.setdefault(hebbian=False)
    # Parameter count of the (adaptive) embedding/softmax, logged below.
    emb_params = count_params(net.embed) + count_params(
        net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c,
                             c.train_batch,
                             split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    c.log('Embedding has %s parameters' % emb_params)

    # Epoch length for the distiller scheduler: explicit override, or derived
    # from the corpus size / batch size / chunk length.
    if c.get("steps_per_epoch"):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print("#### steps per epoch %d ####" % steps_per_epoch)

    if c.hebbian:
        # One counter per adaptive-softmax cluster, partitioned by c.cutoffs.
        counters = [
            torch.ones(end - start, dtype=torch.long, device=c.device)
            for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])
        ]
        temp_counters = [torch.zeros_like(x) for x in counters]

    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        # NOTE(review): .max() is the worst past val_loss and it is never
        # updated in the loop; confirm .min() wasn't intended.
        best_val_loss = s.results['val_loss'].dropna().max()
    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if step % steps_per_epoch == 0:
                c.log("====> batch=%d, epoch=%d, step=%d" %
                      (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)
            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(
                    epoch,
                    minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch)
            step_lr(step)
            x = to_torch(next(iter_tr), c.device).t()
            t_s = time()
            # Next-token setup: inputs tokens[:-1], labels tokens[1:].
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']
            if compression_scheduler:
                # Return value (possibly regularized loss) is deliberately
                # discarded; NOTE(review): confirm the schedule adds no
                # loss terms that should be backpropagated.
                _ = compression_scheduler.before_backward_pass(
                    epoch,
                    minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch,
                    loss=loss,
                    return_loss_components=False)
            opt.zero_grad()
            if torch.isnan(loss):
                raise RuntimeError('Encountered nan loss during training')
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(),
                                           c.get('clip_grad', 0.5))
            opt.step()
            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters,
                                      temp_counters)
            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            # Guard against overflow: e**loss explodes for large losses.
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(dict(
                loss=loss,
                perplexity=perplexity,
                time=time_model
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            if c.use_cache:
                step_result['theta'] = preds['theta']
                step_result['lambda'] = preds['lambda'].item()
            if compression_scheduler:
                compression_scheduler.on_minibatch_end(
                    epoch,
                    minibatch_id=batch,
                    minibatches_per_epoch=steps_per_epoch)
            if step % steps_per_epoch == 0:
                # NOTE(review): this fires on the same iteration as
                # on_epoch_begin above (both gated on step % steps_per_epoch
                # == 0); the sibling variant uses (batch + 1) ==
                # steps_per_epoch — confirm which boundary is intended.
                if compression_scheduler:
                    compression_scheduler.on_epoch_end(epoch)
            s.step = step = step + 1
            if step % c.step_eval == 0:
                # Periodic validation plus sparsity logging to TB/python.
                distiller.log_weights_sparsity(net,
                                               epoch,
                                               loggers=[tflogger, pylogger])
                t, total = distiller.weights_sparsity_tbl_summary(
                    net, return_total_sparsity=True)
                c.log("total sparsity: %.3lf" % total)
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_')
                )
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception as e:
        # Top-level boundary: log the traceback instead of crashing the job.
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
distiller.model_summary(net, "sparsity", 'wikitext-103') eval_output = evaluate(config.var(use_cache=True, n_cache=2000, cache_thetaa=best_theta, cache_lambda=best_lambda), data_test, net) config.log("VAL RESULT: ppl(%.3lf) theta(%.3lf) lambda(%.3lf)" % (best_ppl, best_theta, best_lambda)) config.log("TEST RESULT: %s" % eval_output) return eval_output if __name__ == '__main__': config = Config.from_args() print("config=", config) net = get_net(config) if config.get("summary"): opt = get_opt(config, net) net, opt, step = config.init_model(net, opt=opt, step='max', train=True) config.log("===> summary of model @ step %d" % step) distiller.model_summary(net, config.summary, 'wikitext-103') exit(0) if config.get("compress"): config.log("===> compress from: %s" % config.compress) compression_scheduler = distiller.config.file_config(net, None, config.compress) if config.get('eval_cache_search'): evaluate_cache_search(config, net) else: train(config, net, compression_scheduler)