def evaluate(c, data, net):
    """Evaluate `net` on `data`, carrying recurrent state across batches."""
    clear_gpu_memory()
    was_training = net.training
    net.eval()
    t_s = time()
    with torch.no_grad():
        weights = []
        losses = []
        prevs = None
        for batch in data:
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            preds = net.forward(inputs, labels, prevs=prevs)
            prevs = preds['state']
            losses.append(preds['loss'])
            weights.append(labels.size(0))
        # Average the per-batch losses weighted by token count, so the result
        # is the mean loss over all evaluated tokens.
        weights = np.array(weights)
        weights = weights / weights.sum()
        loss = sum(x * w for x, w in zip(losses, weights))
        if c.distributed:
            gathered_losses = [torch.zeros_like(loss) for _ in range(c.world_size)]
            torch.distributed.all_gather(gathered_losses, loss)
            loss = sum(gathered_losses) / len(gathered_losses)
    if was_training:
        net.train()
    loss = from_torch(loss)
    # Report perplexity only when the loss is in a meaningful range.
    perplexity = np.nan if loss > 5 else np.e ** loss
    return dict(loss=loss, perplexity=perplexity, time=np.round(time() - t_s))
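# Why evaluate() weights batch losses by token count: the weighted mean of
# per-batch mean losses equals the mean over all tokens. A standalone check
# with made-up numbers (illustration only, not part of the pipeline):
#
#   import numpy as np
#   losses = np.array([2.0, 3.0, 2.5])   # per-batch mean losses
#   weights = np.array([128, 64, 32])    # tokens per batch (labels.size(0))
#   weights = weights / weights.sum()
#   avg = (losses * weights).sum()
#   assert np.isclose(avg, (2.0 * 128 + 3.0 * 64 + 2.5 * 32) / 224)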
def profile(c):
    """Profile MACs and parameter count on a single evaluation batch."""
    c.setdefault(hebbian=False, distributed=False)
    net, step = c.var(device='cpu').load_model('max')
    print('profiling')
    net.eval()
    data_val = SequentialIterator(c, c.eval_batch, split='test')
    with torch.no_grad():
        for batch in data_val:
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            macs = profile_macs(net, (inputs, labels))
            # One multiply-accumulate counts as two floating-point operations;
            # dividing by the chunk length gives FLOPs per token.
            print('==> FLOPS: ', macs / c.eval_chunk * 2)
            print('==> Model size: ', count_params(net))
            exit(0)  # profile only the first batch, then exit
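# Note on the FLOPs arithmetic in profile(): profile_macs (presumably from the
# torchprofile package) returns a multiply-accumulate count for one evaluation
# chunk; each MAC is counted as two floating-point ops. Illustrative numbers:
#
#   macs = 3.2e9                             # hypothetical MAC count
#   eval_chunk = 512                         # tokens per chunk
#   flops_per_token = macs / eval_chunk * 2  # = 1.25e7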
def gen_soft_labels(c):
    """Run the trained teacher over the training split and save its top-k
    predictions (token ids and probabilities) as soft labels for distillation."""
    c.setdefault(hebbian=False, distributed=False)
    net = get_net(c)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    print('generating soft labels...')
    data_gen_tr = SequentialIteratorGenSoft(c, c.get('gen_soft_batch'), split='train')
    clear_gpu_memory()
    net.eval()
    with torch.no_grad():
        i = 0
        for batch in tqdm(data_gen_tr):
            x = to_torch(batch, c.device).t()
            inputs, labels = x[:-1], x[1:]
            probs, _ = net(inputs, labels)
            values, indices = torch.topk(probs, c.get('topk'), dim=1)
            indices_ = indices.cpu().numpy()
            values_ = values.cpu().numpy()
            # If the model returns more rows than inputs, keep only the
            # trailing rows that line up with this batch.
            if probs.size(0) != inputs.size(0):
                indices_ = indices_[-inputs.size(0):, :]
                values_ = values_[-inputs.size(0):, :]
            if i == 0:
                all_soft_indices = indices_
                all_soft_values = values_
            else:
                all_soft_indices = np.concatenate((all_soft_indices, indices_), axis=0)
                all_soft_values = np.concatenate((all_soft_values, values_), axis=0)
            i += 1
        # Duplicate the first row so the arrays align one-to-one with the token stream.
        all_soft_indices = np.concatenate((all_soft_indices[0:1, :], all_soft_indices), axis=0)
        all_soft_values = np.concatenate((all_soft_values[0:1, :], all_soft_values), axis=0)
        np.save(c.get('file_out_path') + 'all_soft_indices' + str(c.get('worker')) + '.npy',
                all_soft_indices)
        np.save(c.get('file_out_path') + 'all_soft_values' + str(c.get('worker')) + '.npy',
                all_soft_values)
        # Sanity check: report how often the ground-truth token appears in the
        # teacher's saved top-k predictions.
        in_indices = np.load(c.get('file_out_path') + 'all_soft_indices' + str(c.get('worker')) + '.npy')
        cnt = 0.
        for k in range(len(data_gen_tr.tokens)):
            if data_gen_tr.tokens[k] in in_indices[k]:
                cnt += 1
        print(cnt / len(data_gen_tr.tokens))
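# Minimal sketch (not part of this repo's API) of how the saved top-k arrays
# could be turned into a distillation term. `student_log_probs` is assumed to
# be [n_tokens, n_vocab]; `soft_indices` (int64) and `soft_values` are torch
# tensors built from the .npy files saved above. Name and signature are
# hypothetical.
def topk_distill_loss_sketch(student_log_probs, soft_indices, soft_values):
    # Renormalize the teacher's top-k probabilities so they sum to 1.
    teacher = soft_values / soft_values.sum(dim=1, keepdim=True)
    # Student log-probabilities at the teacher's top-k token ids.
    student = student_log_probs.gather(1, soft_indices)
    # Cross-entropy against the truncated teacher distribution.
    return -(teacher * student).sum(dim=1).mean()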
def train(c):
    """Train a model, optionally distilling from soft labels loaded from file."""
    c.setdefault(hebbian=False)
    net = get_net(c)
    emb_params = count_params(net.embed) + count_params(net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    if c.get('distillation_teacher') == 'file':
        data_tr_distill = DistillationSampleIterator(c, c.train_batch, split='train')
        iter_tr_distill = iter(data_tr_distill)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    c.log('Embedding has %s parameters' % emb_params)
    if c.hebbian:
        counters = [torch.ones(end - start, dtype=torch.long, device=c.device)
                    for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])]
        temp_counters = [torch.zeros_like(x) for x in counters]
    # Lowest validation loss seen so far; evals below it count as record steps.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()
    try:
        while step < s.step_max:
            step_lr(step)
            if c.get('distillation_teacher') == 'file':
                # Distillation branch: batches carry hard labels plus the
                # teacher's top-k token ids and probabilities.
                x_hard_labels, x_soft_labels, x_soft_probs = next(iter_tr_distill)
                x_hard_labels = to_torch(x_hard_labels, c.device).t()
                x_soft_labels = to_torch(x_soft_labels, c.device).permute(1, 0, 2)
                x_soft_probs = to_torch(x_soft_probs, c.device).permute(1, 0, 2)
                inputs, hard_labels = x_hard_labels[:-1], x_hard_labels[1:]
                soft_labels = x_soft_labels[1:]
                soft_probs = x_soft_probs[1:]
                t_s = time()
                preds = net(inputs=inputs, labels=hard_labels,
                            soft_labels=soft_labels, soft_probs=soft_probs,
                            is_distilling=True, current_step=step)
                loss = preds['loss']
                total_loss = loss
                extras = {}
            else:
                x = to_torch(next(iter_tr), c.device).t()
                t_s = time()
                inputs, labels = x[:-1], x[1:]
                preds = net(inputs, labels)
                loss = preds['loss']
                if c.model_class == 'UniversalTransformer':
                    act_loss = preds['act_loss']
                    total_loss = act_loss + loss
                    extras = dict(act_loss=from_torch(act_loss),
                                  n_updates=from_torch(preds['n_updates'].mean()))
                else:
                    total_loss = loss
                    extras = {}
            opt.zero_grad()
            if torch.isnan(total_loss):
                raise RuntimeError('Encountered nan loss during training')
            with amp.scale_loss(total_loss, opt) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()
            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)
            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(Dict(
                loss=loss, perplexity=perplexity, time=time_model, **extras
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            step_result['theta'] = from_torch(preds['theta'])
            step_result['lambda'] = from_torch(preds['lambda'])
            s.step = step = step + 1
            if step % c.step_eval == 0:
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            c.on_step_end(s)
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
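# Shape conventions assumed by the distillation branch in train() above
# (inferred from the transpose/permute calls; not documented in the source):
# after to_torch(...).t() / .permute(1, 0, 2), tensors are time-major:
#   x_hard_labels : [seq_len + 1, batch]          token ids
#   x_soft_labels : [seq_len + 1, batch, topk]    teacher top-k token ids
#   x_soft_probs  : [seq_len + 1, batch, topk]    teacher top-k probabilities
# inputs drops the last timestep and the label tensors drop the first, so the
# prediction at position t is scored against the tokens at position t + 1.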
def train(c, net, compression_scheduler=None):
    """Train `net`, optionally under a distiller CompressionScheduler
    (pruning/regularization callbacks around each epoch and minibatch)."""
    c.setdefault(hebbian=False)
    assert not c.distributed and not c.parallel
    emb_params = count_params(net.embed) + count_params(net.loss.projections) + count_params(net.loss.clusters)
    opt = get_opt(c, net)
    net, opt, step = c.init_model(net, opt=opt, step='max', train=True)
    step_lr = scheduler(c, opt, step)
    data_tr = SampleIterator(c, c.train_batch, split='valid' if c.debug else 'train')
    iter_tr = iter(data_tr)
    data_val = SequentialIterator(c, c.eval_batch, split='valid')
    s = Namespace(net=net, opt=opt, step=step)
    c.on_train_start(s)
    c.log('Embedding has %s parameters' % emb_params)
    if c.get('steps_per_epoch'):
        steps_per_epoch = c.steps_per_epoch
    else:
        steps_per_epoch = len(data_tr.tokens) // data_tr.bs // c.train_chunk
    print('#### steps per epoch %d ####' % steps_per_epoch)
    if c.hebbian:
        counters = [torch.ones(end - start, dtype=torch.long, device=c.device)
                    for start, end in zip([0] + c.cutoffs, c.cutoffs + [c.n_vocab])]
        temp_counters = [torch.zeros_like(x) for x in counters]
    # Lowest validation loss seen so far; evals below it count as record steps.
    best_val_loss = np.inf
    if s.results is not None and 'val_loss' in s.results.columns:
        best_val_loss = s.results['val_loss'].dropna().min()
    try:
        while step < s.step_max:
            batch = step % steps_per_epoch
            epoch = step // steps_per_epoch
            if step % steps_per_epoch == 0:
                c.log('====> batch=%d, epoch=%d, step=%d' % (batch, epoch, step))
                if compression_scheduler:
                    compression_scheduler.on_epoch_begin(epoch)
            if compression_scheduler:
                compression_scheduler.on_minibatch_begin(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)
            step_lr(step)
            x = to_torch(next(iter_tr), c.device).t()
            t_s = time()
            inputs, labels = x[:-1], x[1:]
            preds = net(inputs, labels)
            loss = preds['loss']
            if c.model_class == 'UniversalTransformer':
                act_loss = preds['act_loss']
                total_loss = act_loss + loss
                extras = dict(act_loss=from_torch(act_loss),
                              n_updates=from_torch(preds['n_updates'].mean()))
            else:
                total_loss = loss
                extras = {}
            if compression_scheduler:
                # Gives the scheduler a chance to adjust the loss (e.g. add a
                # regularization term) before the backward pass.
                _ = compression_scheduler.before_backward_pass(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch,
                    loss=total_loss, return_loss_components=False)
            opt.zero_grad()
            if torch.isnan(total_loss):
                raise RuntimeError('Encountered nan loss during training')
            with amp.scale_loss(total_loss, opt) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), c.get('clip_grad', 0.5))
            opt.step()
            if c.hebbian:
                hebbian_weight_update(c, net, preds['hiddens'], counters, temp_counters)
            time_model = np.round(time() - t_s, 5)
            loss = from_torch(loss)
            perplexity = np.nan if loss > 5 else np.e ** loss
            step_result = pd.Series(Dict(
                loss=loss, perplexity=perplexity, time=time_model, **extras
            )).add_prefix('train_')
            step_result['lr'] = next(iter(opt.param_groups))['lr']
            step_result['theta'] = from_torch(preds['theta'])
            step_result['lambda'] = preds['lambda'].item()
            if compression_scheduler:
                compression_scheduler.on_minibatch_end(
                    epoch, minibatch_id=batch, minibatches_per_epoch=steps_per_epoch)
            if step % steps_per_epoch == 0:
                if compression_scheduler:
                    compression_scheduler.on_epoch_end(epoch)
            if step % c.step_eval == 0:
                distiller.log_weights_sparsity(net, epoch, loggers=[tflogger, pylogger])
                t, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
                c.log('total sparsity: %.3lf' % total)
                step_result = step_result.append(
                    pd.Series(evaluate(c, data_val, net)).add_prefix('val_'))
                s.record_step = step_result['val_loss'] < best_val_loss
                clear_gpu_memory()
            s.step_result = step_result
            # Sanity check (disabled):
            # t, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
            # val_result = evaluate(c, data_val, net)
            # c.log('#### before on step end: sparsity %.3lf | val result: %s' % (total, val_result))
            c.on_step_end(s)
            # t, total = distiller.weights_sparsity_tbl_summary(net, return_total_sparsity=True)
            # val_result = evaluate(c, data_val, net)
            # c.log('#### after on step end: sparsity %.3lf | val result: %s' % (total, val_result))
            c.log('@@@@ step %d end @@@@' % step)
            s.step = step = step + 1
    except Exception:
        import traceback
        err = traceback.format_exc()
        if c.main:
            c.log(err)
        else:
            print(err)
    finally:
        c.on_train_end(s)
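# The compression-aware train() above relies only on the callback protocol
# below. A no-op stand-in (a sketch, handy for dry runs without distiller
# installed; NoOpCompressionScheduler is not a distiller class):
class NoOpCompressionScheduler:
    def on_epoch_begin(self, epoch):
        pass

    def on_minibatch_begin(self, epoch, minibatch_id, minibatches_per_epoch):
        pass

    def before_backward_pass(self, epoch, minibatch_id, minibatches_per_epoch,
                             loss, return_loss_components=False):
        # Return the loss unchanged, mirroring how the result is ignored above.
        return loss

    def on_minibatch_end(self, epoch, minibatch_id, minibatches_per_epoch):
        pass

    def on_epoch_end(self, epoch):
        pass

# Usage: train(c, net, compression_scheduler=NoOpCompressionScheduler())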