def plot_histograms(model, x, y, epoch, out_dir):
    """
    Plots histograms of the data, mapping, reconstruction and weights.
    :param model: a C-GAE instance
    :param x: first input batch used for computing histograms
    :param y: second input batch used for computing histograms
    :param epoch: epoch nr
    :param out_dir: directory where to save the histograms
    """
    plot_hist(to_numpy(x), f"data_ep{epoch}",
              os.path.join(out_dir, f"hist_data_ep{epoch}.png"))
    mapping = to_numpy(model.mapping_(x, y))
    plot_hist(mapping, f"mapping_ep{epoch}",
              os.path.join(out_dir, f"hist_map_ep{epoch}.png"))
    output = to_numpy(model(model.mapping_(x, y), y))
    plot_hist(output, f"recon_ep{epoch}",
              os.path.join(out_dir, f"hist_recon_ep{epoch}.png"))
    weight_x = to_numpy(model.conv_x.weight)
    plot_hist(weight_x, f"weightx_ep{epoch}",
              os.path.join(out_dir, f"hist_weightx_ep{epoch}.png"))
    weight_y = to_numpy(model.conv_y.weight)
    plot_hist(weight_y, f"weighty_ep{epoch}",
              os.path.join(out_dir, f"hist_weighty_ep{epoch}.png"))
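# Note: the snippets in this listing rely on a to_numpy() conversion helper that
# is defined elsewhere in each project. A minimal sketch of what such a helper
# is assumed to do (an illustration, not the original implementation):
import numpy as np
import torch


def to_numpy(x):
    # Detach from the autograd graph, move to CPU, and convert to an ndarray.
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)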
def plot_kernels(model, epoch, out_dir):
    """
    Plots the input weights of the C-GAE.
    :param model: a C-GAE instance
    :param epoch: epoch nr (int)
    :param out_dir: directory where to save the plot
    """
    filter_x = to_numpy(model.conv_x.weight)
    make_tiles(filter_x, os.path.join(out_dir, f"filtersx_ep{epoch}.png"))
    filter_y = to_numpy(model.conv_y.weight)
    make_tiles(filter_y, os.path.join(out_dir, f"filtersy_ep{epoch}.png"))
def evaluate(model, config, valid_loader, eval_device=None):
    opt = config['opt']
    device = opt.device
    if eval_device is not None:
        device = eval_device
    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss().to(device)
    preds = None
    ys = None
    model.eval()
    with torch.no_grad():
        iterator = tqdm(valid_loader, total=len(valid_loader), desc="Evaluate")
        for i, (x, y) in enumerate(iterator):
            x = to_device(x, device)
            y = to_device(y, device)
            logits = model(x)
            loss = criterion(logits, y)
            # softmax after computing the cross-entropy loss
            logits = torch.softmax(logits, dim=-1)
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(classification_report(ys, preds_ids,
                                    target_names=label_names, digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        logger.warning(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
def evaluate_maskedLogit(data_path, seuil=None):
    print("loading dataset...")
    datas = loaders.load_from_disk(data_path)
    loaders.set_dataset_format(datas)
    print("loading model...")
    model = models.maskedLogit()
    model.eval()
    model.to(device)
    dataloader = torch.utils.data.DataLoader(datas, batch_size=16)
    outputs = [[], []]
    confusion = [[0, 0], [0, 0]]
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="evaluation"):
            out = model(batch['index'].to(device),
                        batch['token'].to(device),
                        batch['input_ids'].to(device),
                        batch['attention_mask'].to(device))
            for l, o in zip(to_numpy(batch['label']), out):
                outputs[l].append(o.item())
                if seuil is not None:
                    # seuil is the decision threshold; bucket the score into
                    # the binary confusion matrix for this label.
                    confusion[l][int(o.item() < seuil)] += 1
    return outputs, confusion
def train(self, batch, batch_size, gamma, time_stamp):
    self.net.train()
    self.step_cnt += 1
    # skip incomplete batches instead of asserting on them
    if len(batch) < batch_size:
        return None
    states = np.vstack([x.state for x in batch])
    actions = np.vstack([x.action for x in batch])
    rewards = np.vstack([x.reward for x in batch])
    next_states = np.vstack([x.next_state for x in batch])
    done = np.vstack([x.done for x in batch])
    actions = to_torch_var(actions).long()
    Q_pred = self.get_Q(states).gather(1, actions)
    # mask out bootstrapped values for terminal transitions
    done_mask = (~done).astype(float)
    Q_expected = np.max(to_numpy(self.get_actor_Q(next_states)), axis=1)
    Q_expected = np.expand_dims(Q_expected, axis=1)
    Q_expected = done_mask * Q_expected
    Q_target = rewards + gamma * Q_expected
    Q_target = to_torch_var(Q_target)
    self.optimizer.zero_grad()
    loss = self.lossfn(Q_pred, Q_target)
    loss.backward()
    for param in self.net.parameters():
        param.grad.data.clamp_(-1, 1)
    self.optimizer.step()
    if self.step_cnt % self.update_rate == 0:
        self.actor_net.load_state_dict(self.net.state_dict())
    return loss
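# The training step above also assumes a to_torch_var() helper defined
# elsewhere. A minimal sketch under the assumption that the project keeps a
# module-level device (names here are illustrative, not the original code):
import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def to_torch_var(x, dtype=torch.float32):
    # Convert a numpy array (or nested sequence) to a tensor on the device.
    return torch.as_tensor(np.asarray(x), dtype=dtype, device=device)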
def train_one_step(self, data_i, total_steps_so_far):
    images_minibatch = self.prepare_images(data_i)
    # toggle_training_mode() alternates which network is updated on this step
    if self.toggle_training_mode() == "generator":
        losses = self.train_discriminator_one_step(images_minibatch)
    else:
        losses = self.train_generator_one_step(images_minibatch)
    return util.to_numpy(losses)
def testFit(self):
    self.unmixModel = AnalysisBySynthesis(paramFile=self.modelFilename,
                                          wavenumbers=self.wavenumbers,
                                          dtype=torch.float64,
                                          device=device)
    trueAbundances = torch.empty(len(self.unmixModel.endmemberModels),
                                 dtype=torch.float64,
                                 device=device).uniform_(0, 1)
    trueAbundances /= trueAbundances.sum()
    trueSpectra = torch.zeros(len(self.wavenumbers),
                              dtype=torch.float64, device=device)
    for i, model in enumerate(self.unmixModel.endmemberModels.values()):
        for parameter in model.parameters():
            parameter.requires_grad_(False)
            perturb = np.random.uniform(-0.0001, 0.0001)
            parameter.add_(perturb)
        trueSpectra += model.forward() * trueAbundances[i]
    self.unmixModel = AnalysisBySynthesis(paramFile=self.modelFilename,
                                          wavenumbers=self.wavenumbers,
                                          dtype=torch.float64,
                                          device=device)
    # round-trip through numpy to detach the synthetic target from the graph
    trueSpectra = to_numpy(trueSpectra)
    trueSpectra = torch.from_numpy(trueSpectra).type(torch.float64).to(device)
    self.unmixModel.fit(trueSpectra, epochs=100, learningRate=1e-6)
    self.assertTrue(torch.allclose(trueSpectra,
                                   self.unmixModel.predictedSpectra,
                                   rtol=5e-2))
    self.assertTrue(torch.allclose(trueAbundances,
                                   self.unmixModel.abundances,
                                   rtol=1e-1))
    plt.plot(self.wavenumbers.cpu().detach().numpy(),
             trueSpectra.cpu().detach().numpy(), label='truth')
    plt.plot(self.wavenumbers.cpu().detach().numpy(),
             self.unmixModel.predictedSpectra.cpu().detach().numpy(),
             '--', label='model')
    plt.title('Test Fit')
    plt.legend()
    plt.show()
def run_extracter(src):
    indexes = src['index'].to(device)
    in_ids = src['input_ids'].to(device)
    att_masks = src['attention_mask'].to(device)
    with torch.no_grad():
        hidden_states = extracter(indexes, in_ids, att_masks)
    return {'hidden_state': to_numpy(hidden_states)}
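# Hypothetical usage of run_extracter: accumulate hidden states batch by batch,
# assuming `datas` is the dataset loaded as in evaluate_maskedLogit above and
# its format yields torch tensors for the listed columns.
import numpy as np
import torch

dataloader = torch.utils.data.DataLoader(datas, batch_size=32)
hidden = np.concatenate(
    [run_extracter(batch)['hidden_state'] for batch in dataloader], axis=0)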
def plot_data(data, epoch, out_dir):
    """
    Plots input data.
    :param data: a data batch to plot
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    make_tiles(to_numpy(data), os.path.join(out_dir, f"data_ep{epoch}.png"))
def build_onnx_input(config, ort_session, x):
    args = config['args']
    x = to_numpy(x)
    if config['emb_class'] in ['glove', 'elmo']:
        ort_inputs = {
            ort_session.get_inputs()[0].name: x[0],
            ort_session.get_inputs()[1].name: x[1]
        }
        if args.use_char_cnn:
            ort_inputs[ort_session.get_inputs()[2].name] = x[2]
    else:
        # x order must be in sync with the x parameter of BertLSTMCRF.forward().
        # x[0,1,2] : [batch_size, seq_size], input_ids / input_mask / segment_ids
        #            == input_ids / attention_mask / token_type_ids
        # x[3]     : [batch_size, seq_size], pos_ids
        # x[4]     : [batch_size, seq_size, char_n_ctx], char_ids
        # with --bert_use_doc_context
        #   x[5] : [batch_size, seq_size], doc2sent_idx
        #   x[6] : [batch_size, seq_size], doc2sent_mask
        #   x[7] : [batch_size, seq_size], word2token_idx  with --bert_use_subword_pooling
        #   x[8] : [batch_size, seq_size], word2token_mask with --bert_use_subword_pooling
        #   x[9] : [batch_size, seq_size], word_ids        with --bert_use_word_embedding
        # without --bert_use_doc_context
        #   x[5] : [batch_size, seq_size], word2token_idx  with --bert_use_subword_pooling
        #   x[6] : [batch_size, seq_size], word2token_mask with --bert_use_subword_pooling
        #   x[7] : [batch_size, seq_size], word_ids        with --bert_use_word_embedding
        if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1]
            }
        else:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1],
                ort_session.get_inputs()[2].name: x[2]
            }
        if args.bert_use_pos:
            ort_inputs[ort_session.get_inputs()[3].name] = x[3]
        if args.use_char_cnn:
            ort_inputs[ort_session.get_inputs()[4].name] = x[4]
        base_idx = 5
        if args.bert_use_doc_context:
            ort_inputs[ort_session.get_inputs()[base_idx].name] = x[base_idx]
            ort_inputs[ort_session.get_inputs()[base_idx + 1].name] = x[base_idx + 1]
            base_idx += 2
        if args.bert_use_subword_pooling:
            ort_inputs[ort_session.get_inputs()[base_idx].name] = x[base_idx]
            ort_inputs[ort_session.get_inputs()[base_idx + 1].name] = x[base_idx + 1]
            if args.bert_use_word_embedding:
                ort_inputs[ort_session.get_inputs()[base_idx + 2].name] = x[base_idx + 2]
    return ort_inputs
def set_to_norm(self, val):
    """
    Sets the norms of all convolutional kernels of the C-GAE to a specific value.
    :param val: norms of kernels are set to this value
    """
    shape_x = self.conv_x.weight.size()
    conv_x_reshape = self.conv_x.weight.view(shape_x[0], -1)
    norms_x = ((conv_x_reshape**2).sum(1)**.5).view(-1, 1)
    conv_x_reshape = conv_x_reshape / norms_x
    weight_x_new = to_numpy(conv_x_reshape.view(*shape_x)) * val
    self.conv_x.weight.data = cuda_tensor(weight_x_new)
    shape_y = self.conv_y.weight.size()
    conv_y_reshape = self.conv_y.weight.view(shape_y[0], -1)
    norms_y = ((conv_y_reshape**2).sum(1)**.5).view(-1, 1)
    conv_y_reshape = conv_y_reshape / norms_y
    weight_y_new = to_numpy(conv_y_reshape.view(*shape_y)) * val
    self.conv_y.weight.data = cuda_tensor(weight_y_new)
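# Quick sanity check for set_to_norm (illustrative, assumes a C-GAE instance
# named model): after the call, every conv_x kernel should have L2 norm ~= val.
model.set_to_norm(0.1)
w = model.conv_x.weight.view(model.conv_x.weight.size(0), -1)
print(w.norm(dim=1))  # expected: values close to 0.1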
def plot_recon(model, x, y, epoch, out_dir):
    """
    Plots the reconstruction of an input batch.
    :param model: a C-GAE instance
    :param x: first input batch to reconstruct
    :param y: second input batch to reconstruct
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    output = to_numpy(model(model.mapping_(x, y), y))
    make_tiles(output, os.path.join(out_dir, f"recon_ep{epoch}.png"))
def evaluate_classifier(datas, model):
    dataloader = torch.utils.data.DataLoader(datas, batch_size=8)
    confusion = [[0, 0], [0, 0]]
    loss = 0
    num_sample = 0
    criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="evaluation", leave=None):
            pred = model(batch['token'].to(device),
                         batch['hidden_state'].to(torch.float).to(device))
            loss += criterion(pred,
                              batch['label'].to(torch.float).to(device)).item()
            num_sample += len(batch['token'])
            pred = to_numpy(torch.sigmoid(pred))
            for p, l in zip(pred > 0.5, to_numpy(batch['label'])):
                confusion[l][p] += 1
    acc = confusion[0][0] + confusion[1][1]
    return loss / num_sample, acc / num_sample, confusion
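# Hypothetical post-processing of the returned confusion matrix, assuming the
# convention used above: confusion[true_label][predicted_label].
loss, acc, confusion = evaluate_classifier(datas, model)
(tn, fp), (fn, tp) = confusion
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
print(f"precision={precision:.3f} recall={recall:.3f}")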
def inference(self, data):
    config = self.config
    opt = config['opt']
    model = self.model
    labels = self.labels
    logits = model(data)
    predicted = logits.argmax(1)
    predicted = to_numpy(predicted)[0]
    predicted_raw = labels[predicted]
    logger.info("[Model predicted] %s", predicted_raw)
    return predicted_raw
def evaluate(model, config, val_loader):
    model.eval()
    opt = config['opt']
    pad_label_id = config['pad_label_id']
    eval_loss = 0.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(val_loader)
    prog = Progbar(target=n_batches)
    preds = None
    ys = None
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.use_crf:
                logits, prediction = model(x)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(opt.device)
                log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
                loss = -1 * log_likelihood
            else:
                logits = model(x)
                loss = criterion(logits.view(-1, model.label_size), y.view(-1))
            if preds is None:
                if opt.use_crf:
                    preds = to_numpy(prediction)
                else:
                    preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            eval_loss += loss.item()
            prog.update(i + 1, [('eval curr loss', loss.item())])
    eval_loss = eval_loss / n_batches
    if not opt.use_crf:
        preds = np.argmax(preds, axis=2)
    # compute measures using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    for i in range(ys.shape[0]):      # for each sentence
        for j in range(ys.shape[1]):  # for each token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "loss": eval_loss,
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    return ret
def plot_mapping(model, x, y, epoch, out_dir):
    """
    Plots the top-most mapping layer given an input batch.
    :param model: a C-GAE instance
    :param x: first input batch
    :param y: second input batch
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    # x = cuda_variable(batch[0][None, :, :, :])
    # y = cuda_variable(batch[1][None, :, :, :])
    output = np.transpose(to_numpy(model.mapping_(x, y)), axes=(0, 3, 2, 1))
    make_tiles(output, os.path.join(out_dir, f"mapping_ep{epoch}.png"))
def atest_pinv(): a = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]]) b = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]]) C = u.Kron(a, b) u.check_close(a.flatten().norm() * b.flatten().norm(), C.frobenius_norm()) u.check_close(C.frobenius_norm(), 4 * math.sqrt(11635.)) Ci = [[ 0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122, -(175 / 2244) ], [ 1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805, -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561 ], [ -(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935), -(5 / 44), 15 / 374, 15 / 374 ], [ 0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122, -(49 / 2244) ], [ -(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805), 64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805 ], [ 1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935, -(7 / 220), 21 / 1870, 21 / 1870 ], [0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204], [ 1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020, -(8 / 255) ], [ -(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170), -(3 / 170) ]] C = C.expand_vec() C0 = u.to_numpy(C) Ci = torch.tensor(Ci) u.check_close(C @ Ci @ C, C) u.check_close(linalg.pinv(C0), Ci, rtol=1e-5, atol=1e-6) u.check_close(torch.pinverse(C), Ci, rtol=1e-5, atol=1e-6) u.check_close(u.pinv(C), Ci, rtol=1e-5, atol=1e-6) u.check_close(C.pinv(), Ci, rtol=1e-5, atol=1e-6)
def build_onnx_input(config, ort_session, x):
    args = config['args']
    x = to_numpy(x)
    if config['emb_class'] == 'glove':
        ort_inputs = {ort_session.get_inputs()[0].name: x}
    else:
        if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1]
            }
        else:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1],
                ort_session.get_inputs()[2].name: x[2]
            }
    return ort_inputs
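# Once the feed dictionary is built, running the session follows the same
# pattern used elsewhere in this listing (a minimal sketch, assuming
# ort_session, config and an encoded batch x are already prepared):
import torch

ort_inputs = build_onnx_input(config, ort_session, x)
logits = torch.tensor(ort_session.run(None, ort_inputs)[0])
predicted = logits.argmax(dim=-1)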
def inference(opt): # set config config = load_config(opt) if opt.num_threads > 0: torch.set_num_threads(opt.num_threads) config['opt'] = opt # set path: opt.embedding_path, opt.vocab_path, opt.label_path set_path(config) # load pytorch model checkpoint checkpoint = load_checkpoint(opt.model_path, device=opt.device) # prepare model and load parameters model = load_model(config, checkpoint) model.eval() # load onnx model for using onnxruntime if opt.enable_ort: import onnxruntime as ort sess_options = ort.SessionOptions() sess_options.inter_op_num_threads = opt.num_threads sess_options.intra_op_num_threads = opt.num_threads ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options) # enable to use dynamic quantized model (pytorch>=1.3.0) if opt.enable_dqm and opt.device == 'cpu': model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) print(model) # prepare tokenizer tokenizer = prepare_tokenizer(config, model) # prepare labels labels = config['labels'] # inference f_out = open(opt.test_path + '.inference', 'w', encoding='utf-8') total_examples = 0 total_duration_time = 0.0 with torch.no_grad(), open(opt.test_path, 'r', encoding='utf-8') as f: for i, line in enumerate(f): start_time = time.time() sent, label = line.strip().split('\t') x_raw = sent.split() y_raw = label text = ' '.join(x_raw) x = encode_text(config, tokenizer, text) x = to_device(x, opt.device) if opt.enable_ort: x = to_numpy(x) if config['emb_class'] == 'glove': ort_inputs = {ort_session.get_inputs()[0].name: x} else: if config['emb_class'] in [ 'roberta', 'distilbert', 'bart' ]: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1] } else: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1], ort_session.get_inputs()[2].name: x[2] } logits = ort_session.run(None, ort_inputs)[0] logits = to_device(torch.tensor(logits), opt.device) else: logits = model(x) predicted = logits.argmax(1) predicted = to_numpy(predicted)[0] predicted_raw = labels[predicted] f_out.write(text + '\t' + y_raw + '\t' + predicted_raw + '\n') total_examples += 1 if opt.num_examples != 0 and total_examples >= opt.num_examples: logger.info("[Stop Inference] : up to the {} examples".format( total_examples)) break duration_time = float((time.time() - start_time) * 1000) if i != 0: total_duration_time += duration_time logger.info("[Elapsed Time] : {}ms".format(duration_time)) f_out.close() logger.info( "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format( total_duration_time, total_duration_time / (total_examples - 1)))
def evaluate(opt): # set config config = load_config(opt) if opt.num_threads > 0: torch.set_num_threads(opt.num_threads) config['opt'] = opt logger.info("%s", config) # set path set_path(config) # prepare test dataset test_loader = prepare_datasets(config) # load pytorch model checkpoint checkpoint = load_checkpoint(opt.model_path, device=opt.device) # prepare model and load parameters model = load_model(config, checkpoint) model.eval() # convert to onnx if opt.convert_onnx: (x, y) = next(iter(test_loader)) x = to_device(x, opt.device) y = to_device(y, opt.device) convert_onnx(config, model, x) check_onnx(config) logger.info("[ONNX model saved] :{}".format(opt.onnx_path)) # quantize onnx if opt.quantize_onnx: quantize_onnx(opt.onnx_path, opt.quantized_onnx_path) logger.info("[Quantized ONNX model saved] : {}".format( opt.quantized_onnx_path)) return # load onnx model for using onnxruntime if opt.enable_ort: import onnxruntime as ort sess_options = ort.SessionOptions() sess_options.inter_op_num_threads = opt.num_threads sess_options.intra_op_num_threads = opt.num_threads ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options) # convert to tvm format if opt.convert_tvm: (x, y) = next(iter(test_loader)) x = to_device(x, opt.device) y = to_device(y, opt.device) convert_tvm(config, model, x) logger.info("[TVM model saved] : {}".format(opt.tvm_path)) return # enable to use dynamic quantized model (pytorch>=1.3.0) if opt.enable_dqm and opt.device == 'cpu': model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) print(model) # evaluation preds = None ys = None correct = 0 n_batches = len(test_loader) total_examples = 0 whole_st_time = time.time() first_time = time.time() first_examples = 0 total_duration_time = 0.0 with torch.no_grad(): for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)): start_time = time.time() x = to_device(x, opt.device) y = to_device(y, opt.device) if opt.enable_ort: x = to_numpy(x) if config['emb_class'] == 'glove': ort_inputs = {ort_session.get_inputs()[0].name: x} else: if config['emb_class'] in [ 'roberta', 'distilbert', 'bart' ]: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1] } else: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1], ort_session.get_inputs()[2].name: x[2] } logits = ort_session.run(None, ort_inputs)[0] logits = to_device(torch.tensor(logits), opt.device) else: logits = model(x) if preds is None: preds = to_numpy(logits) ys = to_numpy(y) else: preds = np.append(preds, to_numpy(logits), axis=0) ys = np.append(ys, to_numpy(y), axis=0) predicted = logits.argmax(1) correct += (predicted == y).sum().item() cur_examples = y.size(0) total_examples += cur_examples if i == 0: # first one may take longer time, so ignore in computing duration. 
first_time = float((time.time() - first_time) * 1000) first_examples = cur_examples if opt.num_examples != 0 and total_examples >= opt.num_examples: logger.info("[Stop Evaluation] : up to the {} examples".format( total_examples)) break duration_time = float((time.time() - start_time) * 1000) if i != 0: total_duration_time += duration_time ''' logger.info("[Elapsed Time] : {}ms".format(duration_time)) ''' # generate report labels = config['labels'] label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])] preds_ids = np.argmax(preds, axis=1) try: print( classification_report(ys, preds_ids, target_names=label_names, digits=4)) print(labels) print(confusion_matrix(ys, preds_ids)) except Exception as e: logger.warn(str(e)) acc = correct / total_examples whole_time = float((time.time() - whole_st_time) * 1000) avg_time = (whole_time - first_time) / (total_examples - first_examples) # write predictions to file write_prediction(opt, preds, labels) logger.info("[Accuracy] : {:.4f}, {:5d}/{:5d}".format( acc, correct, total_examples)) logger.info("[Elapsed Time] : {}ms, {}ms on average".format( whole_time, avg_time)) logger.info( "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format( total_duration_time, total_duration_time / (total_examples - 1)))
def evaluate(opt): # set config config = load_config(opt) if opt.num_threads > 0: torch.set_num_threads(opt.num_threads) config['opt'] = opt logger.info("%s", config) # set path set_path(config) # prepare test dataset test_loader = prepare_datasets(config) # load pytorch model checkpoint checkpoint = load_checkpoint(config) # prepare model and load parameters model = load_model(config, checkpoint) model.eval() # convert to onnx format if opt.convert_onnx: (x, y) = next(iter(test_loader)) x = to_device(x, opt.device) y = to_device(y, opt.device) convert_onnx(config, model, x) check_onnx(config) logger.info("[ONNX model saved at {}".format(opt.onnx_path)) # quantize onnx if opt.quantize_onnx: quantize_onnx(opt.onnx_path, opt.quantized_onnx_path) logger.info("[Quantized ONNX model saved at {}".format( opt.quantized_onnx_path)) return # load onnx model for using onnxruntime if opt.enable_ort: import onnxruntime as ort sess_options = ort.SessionOptions() sess_options.inter_op_num_threads = opt.num_threads sess_options.intra_op_num_threads = opt.num_threads ort_session = ort.InferenceSession(opt.onnx_path, sess_options=sess_options) # enable to use dynamic quantized model (pytorch>=1.3.0) if opt.enable_dqm and opt.device == 'cpu': model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) print(model) # evaluation preds = None ys = None n_batches = len(test_loader) total_examples = 0 whole_st_time = time.time() first_time = time.time() first_examples = 0 total_duration_time = 0.0 with torch.no_grad(): for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)): start_time = time.time() x = to_device(x, opt.device) y = to_device(y, opt.device) if opt.enable_ort: x = to_numpy(x) if config['emb_class'] == 'glove': ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1] } if opt.use_char_cnn: ort_inputs[ort_session.get_inputs()[2].name] = x[2] if config['emb_class'] in [ 'bert', 'distilbert', 'albert', 'roberta', 'bart', 'electra' ]: if config['emb_class'] in ['distilbert', 'bart']: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1] } else: ort_inputs = { ort_session.get_inputs()[0].name: x[0], ort_session.get_inputs()[1].name: x[1], ort_session.get_inputs()[2].name: x[2] } if opt.bert_use_pos: ort_inputs[ort_session.get_inputs()[3].name] = x[3] if opt.use_crf: logits, prediction = ort_session.run(None, ort_inputs) prediction = to_device(torch.tensor(prediction), opt.device) logits = to_device(torch.tensor(logits), opt.device) else: logits = ort_session.run(None, ort_inputs)[0] logits = to_device(torch.tensor(logits), opt.device) else: if opt.use_crf: logits, prediction = model(x) else: logits = model(x) if preds is None: if opt.use_crf: preds = to_numpy(prediction) else: preds = to_numpy(logits) ys = to_numpy(y) else: if opt.use_crf: preds = np.append(preds, to_numpy(prediction), axis=0) else: preds = np.append(preds, to_numpy(logits), axis=0) ys = np.append(ys, to_numpy(y), axis=0) cur_examples = y.size(0) total_examples += cur_examples if i == 0: # first one may take longer time, so ignore in computing duration. 
first_time = float((time.time() - first_time) * 1000) first_examples = cur_examples if opt.num_examples != 0 and total_examples >= opt.num_examples: logger.info("[Stop Evaluation] : up to the {} examples".format( total_examples)) break duration_time = float((time.time() - start_time) * 1000) if i != 0: total_duration_time += duration_time ''' logger.info("[Elapsed Time] : {}ms".format(duration_time)) ''' whole_time = float((time.time() - whole_st_time) * 1000) avg_time = (whole_time - first_time) / (total_examples - first_examples) if not opt.use_crf: preds = np.argmax(preds, axis=2) # compute measure using seqeval labels = model.labels ys_lbs = [[] for _ in range(ys.shape[0])] preds_lbs = [[] for _ in range(ys.shape[0])] pad_label_id = config['pad_label_id'] for i in range(ys.shape[0]): # foreach sentence for j in range(ys.shape[1]): # foreach token if ys[i][j] != pad_label_id: ys_lbs[i].append(labels[ys[i][j]]) preds_lbs[i].append(labels[preds[i][j]]) ret = { "precision": precision_score(ys_lbs, preds_lbs), "recall": recall_score(ys_lbs, preds_lbs), "f1": f1_score(ys_lbs, preds_lbs), "report": classification_report(ys_lbs, preds_lbs, digits=4), } print(ret['report']) f1 = ret['f1'] # write predicted labels to file default_label = config['default_label'] write_prediction(opt, ys, preds, labels, pad_label_id, default_label) logger.info("[F1] : {}, {}".format(f1, total_examples)) logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format( total_examples, whole_time, avg_time)) logger.info( "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format( total_duration_time, total_duration_time / (total_examples - 1)))
def get_grads(self):
    """ Returns the current gradients of all parameters, flattened into a single vector. """
    return deepcopy(
        np.hstack([to_numpy(v.grad).flatten() for v in self.parameters()]))
def get_params(self):
    """ Returns the parameters of the actor, flattened into a single vector. """
    return deepcopy(
        np.hstack([to_numpy(v).flatten() for v in self.parameters()]))
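# get_params()/get_grads() flatten everything into one numpy vector. A
# hypothetical counterpart that writes such a vector back, assuming the same
# parameter ordering (illustrative, not part of the original class):
import torch


def set_params(module, flat):
    # Copy a flat numpy vector back into the module's parameters, slice by slice.
    offset = 0
    for p in module.parameters():
        n = p.numel()
        p.data.copy_(torch.as_tensor(flat[offset:offset + n]).view_as(p))
        offset += n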
def evaluate(args): # set config config = load_config(args) if args.num_threads > 0: torch.set_num_threads(args.num_threads) config['args'] = args logger.info("%s", config) # set path set_path(config) # prepare test dataset test_loader = prepare_datasets(config) # load pytorch model checkpoint checkpoint = load_checkpoint(args.model_path, device=args.device) # prepare model and load parameters model = load_model(config, checkpoint) model.eval() # convert to onnx format if args.convert_onnx: # FIXME not working for --use_crf batch = next(iter(test_loader)) if config['emb_class'] not in ['glove', 'elmo']: x, y, gy = batch else: x, y = batch x = to_device(x, args.device) convert_onnx(config, model, x) check_onnx(config) logger.info("[ONNX model saved] : {}".format(args.onnx_path)) # quantize onnx if args.quantize_onnx: quantize_onnx(args.onnx_path, args.quantized_onnx_path) logger.info("[Quantized ONNX model saved] : {}".format( args.quantized_onnx_path)) return # load onnx model for using onnxruntime if args.enable_ort: import onnxruntime as ort sess_options = ort.SessionOptions() sess_options.inter_op_num_threads = args.num_threads sess_options.intra_op_num_threads = args.num_threads ort_session = ort.InferenceSession(args.onnx_path, sess_options=sess_options) # enable to use dynamic quantized model (pytorch>=1.3.0) if args.enable_dqm: model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8) print(model) # evaluation preds = None ys = None gpreds = None gys = None n_batches = len(test_loader) total_examples = 0 whole_st_time = time.time() first_time = time.time() first_examples = 0 total_duration_time = 0.0 with torch.no_grad(): for i, batch in enumerate(tqdm(test_loader, total=n_batches)): start_time = time.time() if config['emb_class'] not in ['glove', 'elmo']: x, y, gy = batch gy = to_device(gy, args.device) else: x, y = batch x = to_device(x, args.device) y = to_device(y, args.device) if args.enable_ort: ort_inputs = build_onnx_input(config, ort_session, x) if args.use_crf: # FIXME not working for --use_crf if args.bert_use_mtl: logits, prediction, glogits = ort_session.run( None, ort_inputs) glogits = to_device(torch.tensor(glogits), args.device) else: logits, prediction = ort_session.run(None, ort_inputs) prediction = to_device(torch.tensor(prediction), args.device) logits = to_device(torch.tensor(logits), args.device) else: if args.bert_use_mtl: logits, glogits = ort_session.run(None, ort_inputs) glogits = to_device(torch.tensor(glogits), args.device) else: logits = ort_session.run(None, ort_inputs)[0] logits = to_device(torch.tensor(logits), args.device) logits = torch.softmax(logits, dim=-1) else: if args.use_crf: if args.bert_use_mtl: logits, prediction, glogits = model(x) else: logits, prediction = model(x) else: if args.bert_use_mtl: logits, glogits = model(x) else: logits = model(x) logits = torch.softmax(logits, dim=-1) if preds is None: if args.use_crf: preds = to_numpy(prediction) else: preds = to_numpy(logits) ys = to_numpy(y) else: if args.use_crf: preds = np.append(preds, to_numpy(prediction), axis=0) else: preds = np.append(preds, to_numpy(logits), axis=0) ys = np.append(ys, to_numpy(y), axis=0) if args.bert_use_mtl: glogits = torch.softmax(glogits, dim=-1) if gpreds is None: gpreds = to_numpy(glogits) gys = to_numpy(gy) else: gpreds = np.append(gpreds, to_numpy(glogits), axis=0) gys = np.append(gys, to_numpy(gy), axis=0) cur_examples = y.size(0) total_examples += cur_examples if i == 0: # first one may take longer time, so ignore in 
computing duration. first_time = float((time.time() - first_time) * 1000) first_examples = cur_examples if args.num_examples != 0 and total_examples >= args.num_examples: logger.info("[Stop Evaluation] : up to the {} examples".format( total_examples)) break duration_time = float((time.time() - start_time) * 1000) if i != 0: total_duration_time += duration_time ''' logger.info("[Elapsed Time] : {}ms".format(duration_time)) ''' whole_time = float((time.time() - whole_st_time) * 1000) avg_time = (whole_time - first_time) / (total_examples - first_examples) # generate report for token classification if not args.use_crf: preds = np.argmax(preds, axis=2) # compute measure using seqeval labels = config['labels'] ys_lbs = [[] for _ in range(ys.shape[0])] preds_lbs = [[] for _ in range(ys.shape[0])] pad_label_id = config['pad_label_id'] for i in range(ys.shape[0]): # foreach sentence for j in range(ys.shape[1]): # foreach token if ys[i][j] != pad_label_id: ys_lbs[i].append(labels[ys[i][j]]) preds_lbs[i].append(labels[preds[i][j]]) ret = { "precision": precision_score(ys_lbs, preds_lbs), "recall": recall_score(ys_lbs, preds_lbs), "f1": f1_score(ys_lbs, preds_lbs), "report": classification_report(ys_lbs, preds_lbs, digits=4), } print(ret['report']) # write predicted labels to file write_prediction(config, model, ys, preds, labels) # generate report for sequence classification if args.bert_use_mtl: glabels = config['glabels'] glabel_names = [ v for k, v in sorted(glabels.items(), key=lambda x: x[0]) ] glabel_ids = [k for k in glabels.keys()] gpreds_ids = np.argmax(gpreds, axis=1) try: g_report = sequence_classification_report( gys, gpreds_ids, target_names=glabel_names, labels=glabel_ids, digits=4) g_report_dict = sequence_classification_report( gys, gpreds_ids, target_names=glabel_names, labels=glabel_ids, output_dict=True) g_matrix = confusion_matrix(gys, gpreds_ids) ret['g_report'] = g_report ret['g_report_dict'] = g_report_dict ret['g_f1'] = g_report_dict['micro avg']['f1-score'] ret['g_matrix'] = g_matrix except Exception as e: logger.warn(str(e)) print(ret['g_report']) print(ret['g_f1']) print(ret['g_matrix']) logger.info("[sequence classification F1] : {}, {}".format( ret['g_f1'], total_examples)) # write predicted glabels to file write_gprediction(args, gpreds, glabels) logger.info("[token classification F1] : {}, {}".format( ret['f1'], total_examples)) logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format( total_examples, whole_time, avg_time)) logger.info( "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format( total_duration_time, total_duration_time / (total_examples - 1)))
def test_kron(): """Test kron, vec and vecr identities""" torch.set_default_dtype(torch.float64) a = torch.tensor([1, 2, 3, 4]).reshape(2, 2) b = torch.tensor([5, 6, 7, 8]).reshape(2, 2) u.check_close(u.Kron(a, b).trace(), 65) a = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]]) b = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]]) Ck = u.Kron(a, b) u.check_close(a.flatten().norm() * b.flatten().norm(), Ck.frobenius_norm()) u.check_close(Ck.frobenius_norm(), 4 * math.sqrt(11635.)) Ci = [[ 0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122, -(175 / 2244) ], [ 1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805, -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561 ], [ -(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935), -(5 / 44), 15 / 374, 15 / 374 ], [ 0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122, -(49 / 2244) ], [ -(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805), 64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805 ], [ 1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935, -(7 / 220), 21 / 1870, 21 / 1870 ], [0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204], [ 1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020, -(8 / 255) ], [ -(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170), -(3 / 170) ]] C = Ck.expand() C0 = u.to_numpy(C) Ci = torch.tensor(Ci) u.check_close(C @ Ci @ C, C) u.check_close(Ck.inv().expand(), torch.inverse(Ck.expand())) u.check_close(Ck.inv().expand_vec(), torch.inverse(Ck.expand_vec())) u.check_close(Ck.pinv().expand(), torch.pinverse(Ck.expand())) u.check_close(linalg.pinv(C0), Ci, rtol=1e-5, atol=1e-6) u.check_close(torch.pinverse(C), Ci, rtol=1e-5, atol=1e-6) u.check_close(Ck.inv().expand(), Ci, rtol=1e-5, atol=1e-6) u.check_close(Ck.pinv().expand(), Ci, rtol=1e-5, atol=1e-6) Ck2 = u.Kron(b, 2 * a) u.check_close((Ck @ Ck2).expand(), Ck.expand() @ Ck2.expand()) u.check_close((Ck @ Ck2).expand_vec(), Ck.expand_vec() @ Ck2.expand_vec()) d2 = 3 d1 = 2 G = torch.randn(d2, d1) g = u.vec(G) H = u.Kron(u.random_cov(d1), u.random_cov(d2)) Gt = G.t() gt = g.reshape(1, -1) vecX = u.Vec([1, 2, 3, 4], shape=(2, 2)) K = u.Kron([[5, 6], [7, 8]], [[9, 10], [11, 12]]) u.check_equal(vecX @ K, [644, 706, 748, 820]) u.check_equal(K @ vecX, [543, 655, 737, 889]) u.check_equal(u.matmul(vecX @ K, vecX), 7538) u.check_equal(vecX @ (vecX @ K), 7538) u.check_equal(vecX @ vecX, 30) vecX = u.Vec([1, 2], shape=(1, 2)) K = u.Kron([[5]], [[9, 10], [11, 12]]) u.check_equal(vecX.norm()**2, 5) # check kronecker rules X = torch.tensor([[1., 2], [3, 4]]) A = torch.tensor([[5., 6], [7, 8]]) B = torch.tensor([[9., 10], [11, 12]]) x = u.Vec(X) # kron/vec/vecr identities u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t())) u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x) u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X)) u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B)) def extra_checks(A, X, B): x = u.Vec(X) u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t())) u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x) u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X)) u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B)) u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B).normal_form()) u.check_equal(u.Vecr(A @ X @ B), u.matmul(u.Kron(A, B.t()).normal_form(), u.Vecr(X))) u.check_equal(u.Vec(A @ X @ B), u.matmul(u.Kron(B.t(), A).normal_form(), x)) u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()).normal_form()) u.check_equal(u.Vec(A @ X @ B), x.normal_form() @ u.Kron(B, 
A.t()).normal_form()) u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A).normal_form() @ x.normal_form()) u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()).normal_form() @ u.Vecr(X).normal_form()) u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X).normal_form() @ u.Kron(A.t(), B).normal_form()) # shape checks d1, d2 = 3, 4 extra_checks(torch.ones((d1, d1)), torch.ones((d1, d2)), torch.ones((d2, d2))) A = torch.rand(d1, d1) B = torch.rand(d2, d2) #x = torch.rand((d1*d2)) #X = x.t().reshape(d1, d2) # X = torch.rand((d1, d2)) # x = u.vec(X) x = torch.rand((d1 * d2))
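# For reference, the kron/vec identities exercised above are the classical
# ones; with u.Vec read as column-major vectorization they are
#     vec(A X B)  = (B^T kron A) vec(X)
#     vecr(A X B) = (A kron B^T) vecr(X)
# A tiny self-contained numpy check of the first identity (an illustration,
# independent of the u.* helpers):
import numpy as np

rng = np.random.default_rng(0)
A, X, B = rng.normal(size=(2, 2)), rng.normal(size=(2, 3)), rng.normal(size=(3, 3))
vec = lambda M: M.flatten(order='F')  # column-major vec
assert np.allclose(vec(A @ X @ B), np.kron(B.T, A) @ vec(X))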
def test_lyapunov(): """Test that scipy lyapunov solver works correctly.""" d = 2 n = 3 torch.set_default_dtype(torch.float32) model = Net(d) w0 = torch.tensor([[1, 2]]).float() assert w0.shape[1] == d model.w.weight.data.copy_(w0) X = torch.tensor([[-2, 0, 2], [-1, 1, 3]]).float() assert X.shape[0] == d assert X.shape[1] == n Y = torch.tensor([[0, 1, 2]]).float() assert Y.shape[1] == X.shape[1] data = X.t() # PyTorch expects batch dimension first target = Y.t() assert data.shape[0] == n output = model(data) # residuals, aka e residuals = output - Y.t() def compute_loss(residuals_): return torch.sum(residuals_ * residuals_) / (2 * n) loss = compute_loss(residuals) assert loss - 8.83333 < 1e-5, torch.norm(loss) - 8.83333 # use learning rate 0 to avoid changing parameter vector optim_kwargs = dict( lr=0, momentum=0, weight_decay=0, l2_reg=0, bias_correction=False, acc_steps=1, curv_type="Cov", curv_shapes={"Linear": "Kron"}, momentum_type="preconditioned", ) curv_args = dict(damping=1, ema_decay=1) # todo: damping optimizer = SecondOrderOptimizer(model, **optim_kwargs, curv_kwargs=curv_args) def backward(last_layer: str) -> Callable: """Creates closure that backpropagates either from output layer or from loss layer""" def closure() -> Tuple[Optional[torch.Tensor], torch.Tensor]: optimizer.zero_grad() output = model(data) if last_layer == "output": output.backward(torch.ones_like(target)) return None, output elif last_layer == 'loss': loss = compute_loss(output - target) loss.backward() return loss, output else: assert False, 'last layer must be "output" or "loss"' return closure # loss = compute_loss(output - Y.t()) # loss.backward() loss, output = optimizer.step(closure=backward('loss')) J = X.t() A = model.w.data_input B = model.w.grad_output * n G = residuals.repeat(1, d) * J losses = torch.stack([compute_loss(r) for r in residuals]) g = G.sum(dim=0) / n efisher = G.t() @ G / n sigma = efisher - u.outer(g, g) loss2 = (residuals * residuals).sum() / (2 * n) H = J.t() @ J / n noise_variance = torch.trace(H.inverse() @ sigma) # H is not quite symmetric, make it so H = H + H.t() # Slow way p_sigma = u.lyapunov_lstsq(H, sigma) sigma0 = u.to_numpy(sigma) H0 = u.to_numpy(H) # Alternative faster way p_sigma2 = scipy.linalg.solve_lyapunov(H0, sigma0) print(f"Error 1: {np.max(abs(H0 @ p_sigma2 + p_sigma2 @ H0 - sigma0))}") u.check_close(p_sigma, p_sigma2) # alternative through SVD p_sigma3 = lyapunov_svd(torch.tensor(H0), torch.tensor(sigma0)) u.check_close(p_sigma2, p_sigma3) # alternative through evals p_sigma4 = u.lyapunov_spectral(torch.tensor(H0), torch.tensor(sigma0)) u.check_close(p_sigma2, p_sigma4)
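# The solvers compared in test_lyapunov all target the same continuous
# Lyapunov equation  H P + P H = Sigma  (H symmetric), which is exactly what
# the printed error term max|H P + P H - Sigma| measures. A tiny
# self-contained check with the scipy solver used above:
import numpy as np
import scipy.linalg

H = np.array([[2.0, 0.5], [0.5, 1.0]])
Sigma = np.array([[1.0, 0.2], [0.2, 3.0]])
P = scipy.linalg.solve_lyapunov(H, Sigma)
assert np.max(np.abs(H @ P + P @ H - Sigma)) < 1e-10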
def train_one_step(self, data_i, total_steps_so_far):
    images_minibatch, labels = self.prepare_images(data_i)
    c_losses = self.train_classifier_one_step(images_minibatch, labels)
    self.adjust_lr_if_necessary(total_steps_so_far)
    return util.to_numpy(c_losses)
def main(): np.random.seed(42) parser = ArgumentParser() parser.add_argument('--base_dir', type=str, default='./EDD2020/') # './' parser.add_argument('--model', type=str, default='unetresnet') args = parser.parse_args() base_dir = args.base_dir model_name = args.model success = create_dir(base_dir + 'resized_masks/') if success: resize_my_images(base_dir + 'EDD2020_release-I_2020-01-15/masks/', base_dir + 'resized_masks/', is_masks=True) success = create_dir(base_dir + 'resized_images/') success &= create_dir(base_dir + 'resized_bboxs/') if success: resize_my_images( base_dir + 'EDD2020_release-I_2020-01-15/originalImages/', base_dir + 'resized_images/', is_masks=False, bboxs_src=base_dir + 'EDD2020_release-I_2020-01-15/bbox/', bboxs_dst=base_dir + 'resized_bboxs/') loader = get_edd_loader(base_dir, shuffle_dataset=True) model = get_model(model_name, in_channel=3, n_classes=5) optimizer_func = optim.Adam(model.parameters(), lr=1e-4) scheduler = lr_scheduler.StepLR(optimizer_func, step_size=10, gamma=0.1) trainer = Trainer(model, optimizer=optimizer_func, scheduler=scheduler) trainer.train_model(loader, num_epochs=30) create_dir(base_dir + 'test/') create_dir(base_dir + 'test/images') create_dir(base_dir + 'test/masks') create_dir(base_dir + 'test/masks_pred') create_dir(base_dir + 'test/bboxs') create_dir(base_dir + 'test/bboxs_pred') create_dir(base_dir + 'test/plots') metrics = defaultdict(float) plot = Plot(base_dir + 'test/plots/') index = 0 for epoch, (images, bboxs, masks) in enumerate(loader['test']): masks_preds = trainer.predict(images) # print('B ', bboxs) bboxs = bbox_tensor_to_bbox(bboxs.squeeze(0)) # print('A ', bboxs) bboxs_preds = compute_bboxs_from_masks(masks_preds.squeeze(0)) plot.plot_image_truemask_predictedmask(images, masks, masks_preds, index) plot.plot_image_truebbox_predictedbbox(images, bboxs, bboxs_preds, index) calc_loss(torch.Tensor(masks_preds), masks, metrics) image, mask, mask_pred = images[0], masks[0], masks_preds[0] save_to_tif(base_dir + 'test/masks/MASK_{:03d}.tif'.format(index), to_numpy(mask)) save_to_tif( base_dir + 'test/masks_pred/MASK_PRED_{:03d}.tif'.format(index), mask_pred) image = image.detach().cpu().numpy().swapaxes(0, 2).swapaxes(0, 1) plt.imsave(base_dir + 'test/images/IMG_{:03d}.png'.format(index), image) index += 1 computed_metrics = compute_metrics(metrics, epoch + 1) print_metrics(computed_metrics, 'test')
def main(): parser = argparse.ArgumentParser(description='PyTorch MNIST Example') parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)') parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', help='input batch size for testing (default: 1000)') parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs to train (default: 10)') parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)') parser.add_argument('--log-interval', type=int, default=10, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--save-model', action='store_true', default=False, help='For Saving the current Model') parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases') parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks') parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run') parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers") parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers") parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer") parser.add_argument('--data_width', type=int, default=28) parser.add_argument('--targets_width', type=int, default=28) parser.add_argument('--hess_samples', type=int, default=1, help='number of samples when sub-sampling outputs, 0 for exact hessian') parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian') parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho') parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection') parser.add_argument('--dataset_size', type=int, default=60000) parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection") parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections") parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset') parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--weight_decay', type=float, default=0) parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--dropout', type=int, default=0) parser.add_argument('--swa', type=int, default=0) parser.add_argument('--lmb', type=float, default=1e-3) parser.add_argument('--train_batch_size', type=int, default=64) parser.add_argument('--stats_batch_size', type=int, default=10000) parser.add_argument('--stats_num_batches', type=int, default=1) parser.add_argument('--run_name', type=str, default='noname') parser.add_argument('--launch_blocking', type=int, default=0) parser.add_argument('--sampled', type=int, default=0) parser.add_argument('--curv', type=str, default='kfac', help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full') parser.add_argument('--log_spectra', type=int, default=0) u.seed_random(1) gl.args = parser.parse_args() args = gl.args u.seed_random(1) gl.project_name = 'train_ciresan' u.setup_logdir_and_event_writer(args.run_name) print(f"Logging to {gl.logdir}") d1 = 28 * 28 d = [784, 2500, 2000, 1500, 1000, 500, 10] 
# number of samples per datapoint. Used to normalize kfac model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout) model = model.to(gl.device) autograd_lib.register(model) assert args.dataset_size >= args.stats_batch_size optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True, dataset_size=args.dataset_size) train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True) train_iter = u.infinite_iter(train_loader) assert not args.full_batch, "fixme: validation still uses stats_iter" if not args.full_batch: stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=True, drop_last=True) stats_iter = u.infinite_iter(stats_loader) else: stats_iter = None test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, original_targets=True, dataset_size=args.dataset_size) test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=False) train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=False) loss_fn = torch.nn.CrossEntropyLoss() autograd_lib.add_hooks(model) autograd_lib.disable_hooks() gl.token_count = 0 last_outer = 0 for step in range(args.stats_steps): epoch = gl.token_count // 60000 lr = optimizer.param_groups[0]['lr'] print('token_count', gl.token_count) if last_outer: u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)}) print(f'time: {time.perf_counter() - last_outer:.2f}') last_outer = time.perf_counter() with u.timeit("validate"): val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})') train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})') # save log metrics = {'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss, 'train_loss': train_loss, 'train_accuracy': train_accuracy, 'lr': optimizer.param_groups[0]['lr'], 'momentum': optimizer.param_groups[0].get('momentum', 0)} u.log_scalars(metrics) def mom_update(buffer, val): buffer *= 0.9 buffer += val * 0.1 if not args.skip_stats: # number of samples passed through n = args.stats_batch_size * args.stats_num_batches # quanti forward_stats = defaultdict(lambda: AttrDefault(float)) hessians = defaultdict(lambda: AttrDefault(float)) jacobians = defaultdict(lambda: AttrDefault(float)) fishers = defaultdict(lambda: AttrDefault(float)) # empirical fisher/gradient quad_fishers = defaultdict(lambda: AttrDefault(float)) # gradient statistics that depend on fisher (4th order moments) train_regrets = defaultdict(list) test_regrets1 = defaultdict(list) test_regrets2 = defaultdict(list) train_regrets_opt = defaultdict(list) test_regrets_opt = defaultdict(list) cosines = defaultdict(list) dot_products = defaultdict(list) hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList)) jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList)) fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList)) quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList)) current = None current_histograms = None for i in range(args.stats_num_batches): activations = {} backprops = {} def save_activations(layer, A, _): activations[layer] = A forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A) print('forward') with 
u.timeit("stats_forward"): with autograd_lib.module_hook(save_activations): data, targets = next(stats_iter) output = model(data) loss = loss_fn(output, targets) * len(output) def compute_stats(layer, _, B): A = activations[layer] if current == fishers: backprops[layer] = B # about 27ms per layer with u.timeit('compute_stats'): current[layer].BB += torch.einsum("ni,nj->ij", B, B) # TODO(y): index consistency current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A) current[layer].BA += torch.einsum("ni,nj->ij", B, A) current[layer].a += torch.einsum("ni->i", A) current[layer].b += torch.einsum("nk->k", B) current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum() # compute curvatures in direction of all gradiennts if current is fishers: assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments" hess = hessians[layer] jac = jacobians[layer] Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1)) current[layer].min_norm2 = min(norms) current[layer].median_norm2 = torch.median(norms) current[layer].max_norm2 = max(norms) norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1)) norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1)) current[layer].norm += norms.sum() current_histograms[layer].norms.extend(torch.sqrt(norms)) current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum() current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms)) current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max() current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median() current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms)) current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum() current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max() current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median() current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel() current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel() current[layer].mean_activation += torch.mean(A) current[layer].mean_activation2 += torch.mean(A*A) current[layer].mean_backprop = torch.mean(B) current[layer].mean_backprop2 = torch.mean(B*B) current[layer].norms_hess += torch.sqrt(norms2_hess).sum() current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess)) current[layer].norms_jac += norms2_jac.sum() current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac)) normalized_moments = copy.copy(hessians[layer]) normalized_moments.AA = forward_stats[layer].AA normalized_moments = u.divide_attributes(normalized_moments, n) train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments, approx=args.curv) test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments, approx=args.curv) test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments, approx=args.curv) test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2, m=normalized_moments, approx=args.curv) train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0, m=normalized_moments, approx=args.curv) cosines_ = autograd_lib.offset_cosines(A, B) train_regrets[layer].extend(train_regrets_) test_regrets1[layer].extend(test_regrets1_) test_regrets2[layer].extend(test_regrets2_) train_regrets_opt[layer].extend(train_regrets_opt_) 
test_regrets_opt[layer].extend(test_regrets_opt_) cosines[layer].extend(cosines_) dot_products[layer].extend(autograd_lib.offset_dotprod(A, B)) # statistics of the form g.Sigma.g elif current == quad_fishers: hess = hessians[layer] sigma = fishers[layer] jac = jacobians[layer] Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1)) norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1)) norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1)) norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1)) current[layer].norm += norms.sum() # TODO(y) remove, redundant with norm2 above current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum() current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max() current[layer].curv_sigma_median = skip_nans(norms_sigma / norms).median() current[layer].curv_hess += skip_nans(norms2_hess / norms).sum() current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max() current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean() current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess)) current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean() current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac)) print('backward') with u.timeit("backprop_H"): with autograd_lib.module_hook(compute_stats): current = hessians current_histograms = hessians_histograms autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled, retain_graph=True) # 600 ms current = jacobians current_histograms = jacobians_histograms autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True) # 600 ms current = fishers current_histograms = fishers_histograms model.zero_grad() loss.backward(retain_graph=True) # 60 ms current = quad_fishers current_histograms = quad_fishers_histograms model.zero_grad() loss.backward() # 60 ms print('summarize') for (i, layer) in enumerate(model.layers): stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers} # evaluate stats from # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb for stats_name in stats_dict: s = AttrDict() stats = stats_dict[stats_name][layer] for key in forward_stats[layer]: # print(f'copying {key} in {stats_name}, {layer}') try: assert stats[key] == float() except: f"Trying to overwrite {key} in {stats_name}, {layer}" stats[key] = forward_stats[layer][key] diag: torch.Tensor = stats.diag / n # jacobian: # curv in direction of gradient goes down to roughly 0.3-1 # maximum curvature goes up to 1000-2000 # # Hessian: # max curv goes down to 1, in direction of gradient 0.0001 s.diag_l2 = torch.max(diag) # 40 - 3000 smaller than kfac l2 for jac s.diag_fro = torch.norm( diag) # jacobian grows to 0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also s.diag_trace = diag.sum() # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition s.diag_average = diag.mean() # normalize for mean loss BB = stats.BB / n AA = stats.AA / n # A_evals, _ = torch.symeig(AA) # averaging 120ms per hit, 90 hits # B_evals, _ = torch.symeig(BB) # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals) # 60x larger than diag_l2. layer0/hess has down/up phase transition. 
layer5/jacobian has up/down phase transition s.kfac_trace = torch.trace(AA) * torch.trace(BB) # 0/hess down/up tr, 5/jac sharp phase transition s.kfac_fro = torch.norm(stats.AA) * torch.norm( stats.BB) # 0/hess has down/up tr, 5/jac up/down transition # s.kfac_erank = s.kfac_trace / s.kfac_l2 # first layer has 25, rest 15, all layers go down except last, last noisy # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape) s.diversity = (stats.norm2 / n) / u.norm_squared( stats.BA / n) # gradient diversity. Goes up 3x. Bottom layer has most diversity. Jacobian diversity much less noisy than everythingelse # discrepancy of KFAC based on exact values of diagonal approximation # average difference normalized by average diagonal magnitude diag_kfac = torch.einsum('ll,ii->li', BB, AA) s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs()) u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s)) # openai batch size stat s = AttrDict() hess = hessians[layer] jac = jacobians[layer] fish = fishers[layer] quad_fish = quad_fishers[layer] # the following check passes, but is expensive # if args.stats_num_batches == 1: # u.check_close(fisher[layer].BA, layer.weight.grad) def trsum(A, B): return (A * B).sum() # computes tr(AB') grad = fishers[layer].BA / n s.grad_fro = torch.norm(grad) # get norms s.lyap_hess_max = quad_fish.lyap_hess_max s.lyap_hess_ave = quad_fish.lyap_hess_sum / n s.lyap_jac_max = quad_fish.lyap_jac_max s.lyap_jac_ave = quad_fish.lyap_jac_sum / n s.hess_trace = hess.diag.sum() / n s.jac_trace = jac.diag.sum() / n # Version 1 of Jain stochastic rates, use Hessian for curvature b = args.train_batch_size s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad) s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad) # compute gradient noise statistics # fish.BB has /n factor twice, hence don't need extra /n on fish.AA # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n)) s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n)) s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n) s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n) s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n) s.mean_norm = torch.sqrt(fish.norm2) / n s.min_norm = torch.sqrt(fish.min_norm2) s.median_norm = torch.sqrt(fish.median_norm2) s.max_norm = torch.sqrt(fish.max_norm2) s.enorms = u.norm_squared(grad) s.a_sparsity = fish.a_sparsity s.b_sparsity = fish.b_sparsity s.mean_activation = fish.mean_activation s.msr_activation = torch.sqrt(fish.mean_activation2) s.mean_backprop = fish.mean_backprop s.msr_backprop = torch.sqrt(fish.mean_backprop2) s.norms_centered = fish.norm2 / n - u.norm_squared(grad) s.norms_hess = fish.norms_hess / n s.norms_jac = fish.norms_jac / n s.hess_curv_grad = fish.curv_hess / n # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth. s.hess_curv_grad_max = fish.curv_hess_max # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth. s.hess_curv_grad_median = fish.curv_hess_median # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. 
        s.mean_norm = torch.sqrt(fish.norm2) / n
        s.min_norm = torch.sqrt(fish.min_norm2)
        s.median_norm = torch.sqrt(fish.median_norm2)
        s.max_norm = torch.sqrt(fish.max_norm2)
        s.enorms = u.norm_squared(grad)
        s.a_sparsity = fish.a_sparsity
        s.b_sparsity = fish.b_sparsity
        s.mean_activation = fish.mean_activation
        s.msr_activation = torch.sqrt(fish.mean_activation2)
        s.mean_backprop = fish.mean_backprop
        s.msr_backprop = torch.sqrt(fish.mean_backprop2)
        s.norms_centered = fish.norm2 / n - u.norm_squared(grad)
        s.norms_hess = fish.norms_hess / n
        s.norms_jac = fish.norms_jac / n

        # hess_curv_grad*: phase transition, hits minimum loss in layer 1, then starts going up;
        # other layers take longer to reach the minimum; decreases with depth
        s.hess_curv_grad = fish.curv_hess / n
        s.hess_curv_grad_max = fish.curv_hess_max
        s.hess_curv_grad_median = fish.curv_hess_median

        s.sigma_curv_grad = quad_fish.curv_sigma / n
        s.sigma_curv_grad_max = quad_fish.curv_sigma_max
        s.sigma_curv_grad_median = quad_fish.curv_sigma_median

        s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad
        s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n
        s.band_yaida = 0.25 * lr * s.mean_norm ** 2
        s.band_yaida_centered = 0.25 * lr * s.norms_centered

        # jac_curv_grad*: much lower variance than jac_curv; reaches its peak around 10k steps
        # (where the kfac error also peaks); decreases with depth except for the last layer
        s.jac_curv_grad = fish.curv_jac / n
        s.jac_curv_grad_max = fish.curv_jac_max
        s.jac_curv_grad_median = fish.curv_jac_median

        # OpenAI gradient noise statistics
        s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n)
        s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n)

        (train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_,
         cosines_, dot_products_) = (torch.stack(r[layer]) for r in
                                     (train_regrets, test_regrets1, test_regrets2, train_regrets_opt,
                                      test_regrets_opt, cosines, dot_products))
        s.train_regret = train_regrets_.median()  # use median because outliers make it hard to see the trend
        s.test_regret1 = test_regrets1_.median()
        s.test_regret2 = test_regrets2_.median()
        s.test_regret_opt = test_regrets_opt_.median()
        s.train_regret_opt = train_regrets_opt_.median()

        s.mean_dot_product = torch.mean(dot_products_)
        s.median_dot_product = torch.median(dot_products_)
        s.median_cosine = cosines_.median()
        s.mean_cosine = cosines_.mean()

        # get learning rates
        L1 = s.hess_curv_grad / n
        L2 = s.jac_curv_grad / n
        diversity = (fish.norm2 / n) / u.norm_squared(grad)
        robust_diversity = (fish.norm2 / n) / fish.median_norm2
        dotprod_diversity = fish.median_norm2 / s.median_dot_product
        s.lr1 = 2 / (L1 * diversity)
        s.lr2 = 2 / (L2 * diversity)
        s.lr3 = 2 / (L2 * robust_diversity)
        s.lr4 = 2 / (L2 * dotprod_diversity)

        hess_A = u.symeig_pos_evals(hess.AA / n)
        hess_B = u.symeig_pos_evals(hess.BB / n)
        fish_A = u.symeig_pos_evals(fish.AA / n)
        fish_B = u.symeig_pos_evals(fish.BB / n)
        jac_A = u.symeig_pos_evals(jac.AA / n)
        jac_B = u.symeig_pos_evals(jac.BB / n)

        u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)})
        u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)})
        u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)})
        u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)})
        u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)})
        u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)})

        gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(fish_A, fish_B).flatten(), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(jac_A, jac_B).flatten(), gl.get_global_step())

        s.hess_l2 = max(hess_A) * max(hess_B)
        s.jac_l2 = max(jac_A) * max(jac_B)
        s.fish_l2 = max(fish_A) * max(fish_B)
        s.hess_trace = hess.diag.sum() / n

        s.jain1_sto = 1 / (s.hess_trace + 2 * s.hess_l2)
        s.jain1_det = 1 / s.hess_l2
        s.jain1_lr = (1 / b) * (1 / s.jain1_sto) + (b - 1) / b * (1 / s.jain1_det)
        s.jain1_lr = 2 / s.jain1_lr

        s.regret_ratio = (train_regrets_opt_ / test_regrets_opt_).median()  # ratio between train and test regret, large means overfitting

        u.log_scalars(u.nest_stats(f'layer-{i}', s))
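        # --- Hedged sketch (an assumed equivalent of u.lyapunov_spectral, not the real helper; safe to
        # delete): the rho/mismatch block below needs the solution L of the Lyapunov equation
        # H L + L H = Q with Q = 2*Sigma. With an eigendecomposition H = V diag(h) V', the solution is
        # L = V (C / (h_i + h_j)) V' where C = V' Q V, which is presumably what a "spectral" solver does.
        # torch.symeig is used here only to match the torch.symeig/torch.eig API used elsewhere in this file.
        def _lyapunov_spectral_sketch(H, Q):
            h, V = torch.symeig(H, eigenvectors=True)
            C = V.t() @ Q @ V
            return V @ (C / (h.unsqueeze(0) + h.unsqueeze(1))) @ V.t()

        _Hs = torch.eye(4) + 0.1 * torch.ones(4, 4)   # toy SPD curvature
        _Qs = torch.eye(4)                            # toy noise covariance
        _Ls = _lyapunov_spectral_sketch(_Hs, _Qs)
        assert torch.allclose(_Hs @ _Ls + _Ls @ _Hs, _Qs, atol=1e-4)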
        # compute stats that would let you bound rho
        if i == 0:  # only compute this once, for the output layer
            hhh = hessians[model.layers[-1]].BB / n
            fff = fishers[model.layers[-1]].BB / n
            d = fff.shape[0]

            L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8)
            L_evals = u.symeig_pos_evals(L)

            Lcheap = fff @ u.pinv(hhh, cond=1e-8)
            Lcheap_evals = u.eig_real(Lcheap)

            u.log_scalars({f'mismatch/rho': d / erank(L_evals)})
            u.log_scalars({f'mismatch/rho_cheap': d / erank(Lcheap_evals)})
            u.log_scalars({f'mismatch/diagonalizability': erank(L_evals) / erank(Lcheap_evals)})  # 1 means diagonalizable
            u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False)
            u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False)
            u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True)
            u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True)

        gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step())
        gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step())

        if args.log_spectra:
            with u.timeit('spectrum'):
                # 2/alpha
                # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det
                # s.jain1_lr = 1 / s.jain1_lr
                # hess.diag_trace, jac.diag_trace

                # Version 2 of Jain stochastic rates, use Jacobian squared for curvature
                s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave
                s.jain2_det = s.jac_l2
                s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det
                s.jain2_lr = 1 / s.jain2_lr

                u.log_spectrum(f'layer-{i}/hess_A', hess_A)
                u.log_spectrum(f'layer-{i}/hess_B', hess_B)
                u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten())
                u.log_spectrum(f'layer-{i}/jac_A', jac_A)
                u.log_spectrum(f'layer-{i}/jac_B', jac_B)
                u.log_spectrum(f'layer-{i}/fish_A', fish_A)
                u.log_spectrum(f'layer-{i}/fish_B', fish_B)
                u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum() / hess_B.sum()})

                L = torch.eig(u.lyapunov_spectral(hess.BB, 2 * fish.BB, cond=1e-8))[0]
                L = L[:, 0]  # extract real part
                L = L.sort()[0]
                L = torch.flip(L, [0])

                L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0]
                L_cheap = L_cheap[:, 0]  # extract real part
                L_cheap = L_cheap.sort()[0]
                L_cheap = torch.flip(L_cheap, [0])

                d = len(hess_B)
                u.log_spectrum(f'layer-{i}/Lyap', L)
                u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap)
                u.log_scalars({f'layer-{i}/dims': d})
                u.log_scalars({f'layer-{i}/L_erank': erank(L)})
                u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)})
                u.log_scalars({f'layer-{i}/rho': d / erank(L)})
                u.log_scalars({f'layer-{i}/rho_cheap': d / erank(L_cheap)})

    model.train()
    with u.timeit('train'):
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            if args.weight_decay:
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.data.mul_(1 - args.weight_decay)
            gl.token_count += data.shape[0]

gl.event_writer.close()
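# --- Footnote sketch (illustration only, not part of the training run): the in-loop
# `param.data.mul_(1 - args.weight_decay)` above is decoupled weight decay: weights are shrunk after
# the optimizer step instead of folding an L2 penalty into the gradient. For plain SGD the two nearly
# coincide when the L2 coefficient is weight_decay / lr; the toy numbers below are hypothetical.
_w, _g, _lr, _wd = torch.tensor([1.0, -2.0]), torch.tensor([0.3, 0.1]), 0.1, 0.01
_w_decoupled = (_w - _lr * _g) * (1 - _wd)           # SGD step, then shrink (as in the loop above)
_w_l2 = _w - _lr * (_g + (_wd / _lr) * _w)           # SGD step with the L2 term folded into the gradient
assert torch.allclose(_w_decoupled, _w_l2, atol=_lr * _wd)  # differ only by lr * wd * g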