Exemplo n.º 1
0
def plot_histograms(model, x, y, epoch, out_dir):
    """
    Plots histograms of the data, mapping activations, reconstruction and
    input weights of a C-GAE, one PNG per histogram.

    :param model: a C-GAE instance
    :param x: input batch (first view) used for computing the histograms
    :param y: input batch (second view) used for computing the histograms
    :param epoch: epoch nr
    :param out_dir: directory where to save the histograms
    """
    # Histogram of the raw input data.
    plot_hist(to_numpy(x), f"data_ep{epoch}",
              os.path.join(out_dir, f"hist_data_ep{epoch}.png"))

    # Histogram of the top-most mapping layer activations.
    mapping = to_numpy(model.mapping_(x, y))
    plot_hist(mapping, f"mapping_ep{epoch}",
              os.path.join(out_dir, f"hist_map_ep{epoch}.png"))
    # Histogram of the reconstruction computed from the mapping and y.
    output = to_numpy(model(model.mapping_(x, y), y))
    plot_hist(output, f"recon_ep{epoch}",
              os.path.join(out_dir, f"hist_recon_ep{epoch}.png"))

    # Histograms of the two convolutional input weight tensors.
    weight_x = to_numpy(model.conv_x.weight)
    plot_hist(weight_x, f"weightx_ep{epoch}",
              os.path.join(out_dir, f"hist_weightx_ep{epoch}.png"))
    weight_y = to_numpy(model.conv_y.weight)
    plot_hist(weight_y, f"weighty_ep{epoch}",
              os.path.join(out_dir, f"hist_weighty_ep{epoch}.png"))
Exemplo n.º 2
0
def plot_kernels(model, epoch, out_dir):
    """
    Plots the input weights of the C-GAE.

    :param model: a C-GAE instance
    :param epoch: epoch nr (int)
    :param out_dir: directory where to save the plot
    """
    # Tile each convolutional weight tensor into its own image file.
    for conv, tag in ((model.conv_x, "filtersx"), (model.conv_y, "filtersy")):
        kernels = to_numpy(conv.weight)
        make_tiles(kernels, os.path.join(out_dir, f"{tag}_ep{epoch}.png"))
Exemplo n.º 3
0
def evaluate(model, config, valid_loader, eval_device=None):
    """
    Evaluates a classifier on a validation loader and prints a report.

    :param model: classifier returning logits of shape [batch, n_classes]
    :param config: dict holding 'opt' (with .device) and 'labels'
    :param valid_loader: iterable of (x, y) batches
    :param eval_device: optional device overriding config['opt'].device
    :return: (average per-example loss, accuracy)
    """
    opt = config['opt']
    device = opt.device
    # identity check against None instead of `!= None`
    if eval_device is not None:
        device = eval_device
    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss().to(device)
    preds = None
    ys = None
    model.eval()  # hoisted out of the loop; no need to set per batch
    with torch.no_grad():
        iterator = tqdm(valid_loader,
                        total=len(valid_loader),
                        desc="Evaluate")
        for i, (x, y) in enumerate(iterator):
            x = to_device(x, device)
            y = to_device(y, device)
            logits = model(x)
            loss = criterion(logits, y)
            # softmax after computing cross entropy loss
            logits = torch.softmax(logits, dim=-1)
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            # weight batch loss by batch size so the final division
            # yields a true per-example average
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(
            classification_report(ys,
                                  preds_ids,
                                  target_names=label_names,
                                  digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        # report generation is best-effort; Logger.warn is deprecated,
        # use Logger.warning
        logger.warning(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
Exemplo n.º 4
0
def evaluate_maskedLogit(data_path, seuil=None):
    """
    Evaluates a maskedLogit model over a dataset stored on disk.

    :param data_path: path passed to `loaders.load_from_disk`
    :param seuil: optional decision threshold ("seuil" = French for
        threshold); when given, a 2x2 confusion matrix is also accumulated
    :return: (outputs, confusion) where outputs[label] is the list of model
        scores for examples of that true label, and
        confusion[true_label][col] counts decisions (col 1 when score < seuil)
    """
    print("loading dataset...")
    datas = loaders.load_from_disk(data_path)
    loaders.set_dataset_format(datas)

    print("loading model...")
    model = models.maskedLogit()
    model.eval()
    model.to(device)

    dataloader = torch.utils.data.DataLoader(datas, batch_size=16)

    # outputs[0] / outputs[1]: raw scores grouped by true label
    outputs = [[], []]
    # confusion[true_label][decision]
    confusion = [[0, 0], [0, 0]]

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="evaluation"):
            out = model(batch['index'].to(device), batch['token'].to(device),
                        batch['input_ids'].to(device),
                        batch['attention_mask'].to(device))
            for l, o in zip(to_numpy(batch['label']), out):
                outputs[l].append(o.item())
                if seuil is not None:
                    # the boolean (o < seuil) indexes column 0/1
                    confusion[l][o < seuil] += 1
    return outputs, confusion
Exemplo n.º 5
0
    def train(self, batch, batch_size, gamma, time_stamp):
        """
        Performs one DQN training step on a batch of transitions.

        :param batch: sequence of transitions exposing .state/.action/
            .reward/.next_state/.done; must contain `batch_size` items
        :param batch_size: expected number of transitions in `batch`
        :param gamma: discount factor for the bootstrapped targets
        :param time_stamp: unused here; kept for interface compatibility
        :return: the training loss tensor, or None for an underfull batch
        """
        self.net.train()
        self.step_cnt += 1
        assert (len(batch) == batch_size)
        if len(batch) < batch_size:
            # fallback for when assertions are disabled (python -O)
            return None

        states = np.vstack([x.state for x in batch])
        actions = np.vstack([x.action for x in batch])
        rewards = np.vstack([x.reward for x in batch])
        next_states = np.vstack([x.next_state for x in batch])
        done = np.vstack([x.done for x in batch])
        actions = to_torch_var(actions).long()

        Q_pred = self.get_Q(states).gather(1, actions)
        # np.float was removed in NumPy 1.24; it was an alias of the
        # builtin float (i.e. float64), which is used here instead.
        done_mask = (~done).astype(float)
        Q_expected = np.max(to_numpy(self.get_actor_Q(next_states)), axis=1)
        Q_expected = np.expand_dims(Q_expected, axis=1)
        # zero out the bootstrapped value for terminal transitions
        Q_expected = done_mask * Q_expected
        Q_target = rewards + gamma * Q_expected
        Q_target = to_torch_var(Q_target)

        self.optimizer.zero_grad()
        loss = self.lossfn(Q_pred, Q_target)
        loss.backward()
        # clip gradients element-wise for stability
        for param in self.net.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        # periodically sync the target (actor) network with the online net
        if self.step_cnt % self.update_rate == 0:
            self.actor_net.load_state_dict(self.net.state_dict())
        return loss
 def train_one_step(self, data_i, total_steps_so_far):
     """
     Runs a single alternating GAN training step.

     :param data_i: raw data for this iteration, turned into image
         minibatches by self.prepare_images
     :param total_steps_so_far: unused here; kept for interface parity
     :return: the step's losses converted via util.to_numpy
     """
     images_minibatch = self.prepare_images(data_i)
     # NOTE(review): when the toggle reports "generator" the
     # *discriminator* is trained (and vice versa) — presumably the
     # toggle returns the mode just completed; confirm against
     # toggle_training_mode().
     if self.toggle_training_mode() == "generator":
         losses = self.train_discriminator_one_step(images_minibatch)
     else:
         losses = self.train_generator_one_step(images_minibatch)
     return util.to_numpy(losses)
Exemplo n.º 7
0
    def testFit(self):
        """
        End-to-end fit test: synthesizes a target spectrum from slightly
        perturbed endmember models mixed with random abundances, then checks
        that a freshly constructed AnalysisBySynthesis model recovers both
        the spectrum and the abundances, and plots truth vs. model.
        """
        self.unmixModel = AnalysisBySynthesis(paramFile=self.modelFilename, wavenumbers=self.wavenumbers,
                                      dtype=torch.float64, device=device)
        # Random ground-truth abundances, normalized to sum to 1.
        trueAbundances = torch.empty(len(self.unmixModel.endmemberModels), dtype=torch.float64, device=device).uniform_(0, 1)
        trueAbundances /= trueAbundances.sum()
        trueSpectra = torch.zeros(len(self.wavenumbers), dtype=torch.float64, device=device)
        for i, model in enumerate(self.unmixModel.endmemberModels.values()):
            for parameter in model.parameters():
                parameter.requires_grad_(False)
                # small random perturbation so the target differs slightly
                # from a freshly initialized model's parameters
                perturb = np.random.uniform(-.0001, 0.0001)
                parameter.add_(perturb)
            trueSpectra += model.forward()*trueAbundances[i]

        # Re-create the model so fitting starts from unperturbed parameters.
        self.unmixModel = AnalysisBySynthesis(paramFile=self.modelFilename, wavenumbers=self.wavenumbers,
                                      dtype=torch.float64, device=device)
        # Round-trip through numpy detaches the target from the graph built
        # while synthesizing it.
        trueSpectra = to_numpy(trueSpectra)
        trueSpectra = torch.from_numpy(trueSpectra).type(torch.float64).to(device)
        self.unmixModel.fit(trueSpectra, epochs=100, learningRate=1e-6)
        self.assertTrue(torch.allclose(trueSpectra, self.unmixModel.predictedSpectra, rtol=5e-2))
        self.assertTrue(torch.allclose(trueAbundances, self.unmixModel.abundances, rtol=1e-1))

        plt.plot(self.wavenumbers.cpu().detach().numpy(),
                 trueSpectra.cpu().detach().numpy(),
                 label='truth')
        plt.plot(self.wavenumbers.cpu().detach().numpy(),
                 self.unmixModel.predictedSpectra.cpu().detach().numpy(),
                 '--',
                 label='model')
        plt.title('Test Fit')
        plt.legend()
        plt.show()
Exemplo n.º 8
0
    def run_extracter(src):
        """
        Maps a batch dict to extractor hidden states.

        :param src: dict with 'index', 'input_ids' and 'attention_mask'
            tensors
        :return: dict with key 'hidden_state' holding the numpy output
        """
        index_t = src['index'].to(device)
        ids_t = src['input_ids'].to(device)
        mask_t = src['attention_mask'].to(device)

        with torch.no_grad():
            states = extracter(index_t, ids_t, mask_t)
        return {'hidden_state': to_numpy(states)}
Exemplo n.º 9
0
def plot_data(data, epoch, out_dir):
    """
    Plots input data.

    :param data: a data batch to plot
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    target = os.path.join(out_dir, f"data_ep{epoch}.png")
    make_tiles(to_numpy(data), target)
Exemplo n.º 10
0
def build_onnx_input(config, ort_session, x):
    """
    Builds the onnxruntime feed dict for a converted sequence-labeling model.

    The number and order of feeds depend on the embedding class and on the
    feature flags in config['args']; the index bookkeeping below must stay
    in sync with BertLSTMCRF.forward().

    :param config: dict with 'emb_class' and 'args' (feature flags)
    :param ort_session: onnxruntime InferenceSession
    :param x: model input tuple; converted to numpy before feeding
    :return: dict mapping ONNX input names to numpy arrays
    """
    args = config['args']
    x = to_numpy(x)
    if config['emb_class'] in ['glove', 'elmo']:
        ort_inputs = {
            ort_session.get_inputs()[0].name: x[0],
            ort_session.get_inputs()[1].name: x[1]
        }
        if args.use_char_cnn:
            ort_inputs[ort_session.get_inputs()[2].name] = x[2]
    else:
        # x order must be sync with x parameter of BertLSTMCRF.forward().
        # x[0,1,2] : [batch_size, seq_size], input_ids / input_mask / segment_ids == input_ids / attention_mask / token_type_ids
        # x[3] :     [batch_size, seq_size], pos_ids
        # x[4] :     [batch_size, seq_size, char_n_ctx], char_ids

        # with --bert_use_doc_context
        # x[5] :     [batch_size, seq_size], doc2sent_idx
        # x[6] :     [batch_size, seq_size], doc2sent_mask
        # x[7] :     [batch_size, seq_size], word2token_idx  with --bert_use_subword_pooling
        # x[8] :     [batch_size, seq_size], word2token_mask with --bert_use_subword_pooling
        # x[9] :     [batch_size, seq_size], word_ids        with --bert_use_word_embedding

        # without --bert_use_doc_context
        # x[5] :     [batch_size, seq_size], word2token_idx  with --bert_use_subword_pooling
        # x[6] :     [batch_size, seq_size], word2token_mask with --bert_use_subword_pooling
        # x[7] :     [batch_size, seq_size], word_ids        with --bert_use_word_embedding
        if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
            # these models take no token_type_ids input
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1]
            }
        else:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1],
                ort_session.get_inputs()[2].name: x[2]
            }
        if args.bert_use_pos:
            ort_inputs[ort_session.get_inputs()[3].name] = x[3]
        if args.use_char_cnn:
            ort_inputs[ort_session.get_inputs()[4].name] = x[4]
        # optional feeds start at index 5; base_idx advances as each
        # optional feature consumes slots
        base_idx = 5
        if args.bert_use_doc_context:
            ort_inputs[ort_session.get_inputs()[base_idx].name] = x[base_idx]
            ort_inputs[ort_session.get_inputs()[base_idx +
                                                1].name] = x[base_idx + 1]
            base_idx += 2
        if args.bert_use_subword_pooling:
            ort_inputs[ort_session.get_inputs()[base_idx].name] = x[base_idx]
            ort_inputs[ort_session.get_inputs()[base_idx +
                                                1].name] = x[base_idx + 1]
            if args.bert_use_word_embedding:
                ort_inputs[ort_session.get_inputs()[base_idx +
                                                    2].name] = x[base_idx + 2]
    return ort_inputs
Exemplo n.º 11
0
    def set_to_norm(self, val):
        """
        Sets the norms of all convolutional kernels of the C-GAE to a specific
        value.

        :param val: norms of kernels are set to this value
        """
        def rescale(conv):
            # Normalize each kernel (a row of the flattened weight) to unit
            # L2 norm, then scale to the requested value. The numpy round
            # trip mirrors the original implementation.
            shape = conv.weight.size()
            flat = conv.weight.view(shape[0], -1)
            norms = ((flat ** 2).sum(1) ** .5).view(-1, 1)
            flat = flat / norms
            new_weight = to_numpy(flat.view(*shape)) * val
            conv.weight.data = cuda_tensor(new_weight)

        # identical treatment for both input convolutions, previously
        # duplicated inline
        rescale(self.conv_x)
        rescale(self.conv_y)
Exemplo n.º 12
0
def plot_recon(model, x, y, epoch, out_dir):
    """
    Plots the reconstruction of an input batch

    :param model: a C-GAE instance
    :param x: input batch (first view) to reconstruct
    :param y: input batch (second view) conditioning the reconstruction
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    # reconstruct x from the mapping of (x, y) together with y
    output = to_numpy(model(model.mapping_(x, y), y))
    make_tiles(output, os.path.join(out_dir, f"recon_ep{epoch}.png"))
Exemplo n.º 13
0
def evaluate_classifier(datas, model):
    """
    Evaluates a binary classifier over a dataset.

    :param datas: dataset yielding dicts with 'token', 'hidden_state'
        and 'label'
    :param model: classifier returning one logit per example
    :return: (mean loss, accuracy, confusion) where
        confusion[true_label][prediction] holds raw counts
    """
    dataloader = torch.utils.data.DataLoader(datas, batch_size=8)
    # confusion[true][pred]; indexed below with booleans (False=0, True=1)
    confusion = [[0, 0], [0, 0]]
    loss = 0
    num_sample = 0

    # 'sum' reduction so dividing by num_sample gives a true per-sample mean
    criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="evaluation", leave=None):
            pred = model(batch['token'].to(device),
                         batch['hidden_state'].to(torch.float).to(device))
            loss += criterion(pred, batch['label'].to(
                torch.float).to(device)).item()
            num_sample += len(batch['token'])

            # sigmoid + 0.5 threshold for the hard decision
            pred = to_numpy(torch.sigmoid(pred))
            for p, l in zip(pred > 0.5, to_numpy(batch['label'])):
                confusion[l][p] += 1

    # diagonal of the confusion matrix = number of correct predictions
    acc = confusion[0][0] + confusion[1][1]
    return loss / num_sample, acc / num_sample, confusion
Exemplo n.º 14
0
    def inference(self, data):
        """
        Runs the classifier on a single encoded example.

        :param data: model input (already encoded and batched)
        :return: the raw label string predicted for the first example
        """
        # unused locals (config/opt) removed; only model and labels are read
        model = self.model
        labels = self.labels

        logits = model(data)
        predicted = logits.argmax(1)
        # take the first (and only) example of the batch
        predicted = to_numpy(predicted)[0]
        predicted_raw = labels[predicted]
        logger.info("[Model predicted] %s", predicted_raw)

        return predicted_raw
Exemplo n.º 15
0
def evaluate(model, config, val_loader):
    """
    Evaluates a sequence-labeling model and prints a seqeval report.

    :param model: tagger; with opt.use_crf it returns (logits, prediction)
    :param config: dict with 'opt' (device, use_crf) and 'pad_label_id'
    :param val_loader: iterable of (x, y) batches
    :return: dict with loss, precision, recall, f1 and the text report
    """
    model.eval()
    opt = config['opt']
    pad_label_id = config['pad_label_id']

    eval_loss = 0.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(val_loader)
    prog = Progbar(target=n_batches)
    preds = None
    ys = None
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.use_crf:
                logits, prediction = model(x)
                # non-zero token ids mark real (non-padding) positions
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(
                    opt.device)
                log_likelihood = model.crf(logits,
                                           y,
                                           mask=mask,
                                           reduction='mean')
                # CRF loss is the negative log-likelihood
                loss = -1 * log_likelihood
            else:
                logits = model(x)
                loss = criterion(logits.view(-1, model.label_size), y.view(-1))
            # accumulate predictions/targets across batches as numpy arrays
            if preds is None:
                if opt.use_crf: preds = to_numpy(prediction)
                else: preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            eval_loss += loss.item()
            prog.update(i + 1, [('eval curr loss', loss.item())])
    eval_loss = eval_loss / n_batches
    # without CRF, preds holds logits: reduce to label ids per token
    if not opt.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            # padding positions are excluded from the metrics
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "loss": eval_loss,
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    return ret
Exemplo n.º 16
0
def plot_mapping(model, x, y, epoch, out_dir):
    """
    Plots the top most mapping layer given an input batch

    :param model: a C-GAE instance
    :param x: input batch (first view)
    :param y: input batch (second view)
    :param epoch: epoch nr
    :param out_dir: directory where to save the plot
    """
    # Move the channel axis last (and swap the two remaining axes) before
    # tiling. NOTE(review): assumes mapping_ returns a 4-D array — confirm.
    output = np.transpose(to_numpy(model.mapping_(x, y)), axes=(0, 3, 2, 1))
    make_tiles(output, os.path.join(out_dir, f"mapping_ep{epoch}.png"))
Exemplo n.º 17
0
def atest_pinv():
    """
    Checks Kronecker-factored pseudo-inverse identities against a
    hand-computed reference inverse (hence the exact rational entries in Ci).
    NOTE(review): the 'atest_' name presumably keeps the test runner from
    collecting it — confirm before renaming.
    """
    a = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]])
    b = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]])
    C = u.Kron(a, b)
    # Frobenius norm of a Kronecker product factorizes into the product
    # of the factors' norms.
    u.check_close(a.flatten().norm() * b.flatten().norm(), C.frobenius_norm())

    u.check_close(C.frobenius_norm(), 4 * math.sqrt(11635.))

    # Reference pseudo-inverse of C, entered as exact rationals.
    Ci = [[
        0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122,
        -(175 / 2244)
    ],
          [
              1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805,
              -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561
          ],
          [
              -(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935),
              -(5 / 44), 15 / 374, 15 / 374
          ],
          [
              0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122,
              -(49 / 2244)
          ],
          [
              -(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805),
              64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805
          ],
          [
              1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935,
              -(7 / 220), 21 / 1870, 21 / 1870
          ], [0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204],
          [
              1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020,
              -(8 / 255)
          ],
          [
              -(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170),
              -(3 / 170)
          ]]
    C = C.expand_vec()
    C0 = u.to_numpy(C)
    Ci = torch.tensor(Ci)
    # Moore-Penrose identity: C @ C+ @ C == C
    u.check_close(C @ Ci @ C, C)

    # The reference inverse must agree with every pinv implementation.
    u.check_close(linalg.pinv(C0), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(torch.pinverse(C), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(u.pinv(C), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(C.pinv(), Ci, rtol=1e-5, atol=1e-6)
Exemplo n.º 18
0
def build_onnx_input(config, ort_session, x):
    """
    Builds the onnxruntime feed dict for a converted classification model.

    :param config: dict with 'emb_class' selecting the input layout
    :param ort_session: onnxruntime InferenceSession
    :param x: model input, converted to numpy before feeding
    :return: dict mapping ONNX input names to numpy arrays
    """
    # the unused local `args = config['args']` was removed
    x = to_numpy(x)
    if config['emb_class'] == 'glove':
        # single tensor input
        ort_inputs = {ort_session.get_inputs()[0].name: x}
    else:
        if config['emb_class'] in ['roberta', 'distilbert', 'bart']:
            # these models take no token_type_ids input
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1]
            }
        else:
            ort_inputs = {
                ort_session.get_inputs()[0].name: x[0],
                ort_session.get_inputs()[1].name: x[1],
                ort_session.get_inputs()[2].name: x[2]
            }
    return ort_inputs
Exemplo n.º 19
0
def inference(opt):
    """
    Runs single-sentence classification inference over a TSV test file.

    Reads opt.test_path (sentence<TAB>label per line), predicts a label per
    line with either the PyTorch model or an ONNX session, and writes
    opt.test_path + '.inference' with text/gold/predicted columns.

    :param opt: namespace with paths, device and runtime flags
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt

    # set path: opt.embedding_path, opt.vocab_path, opt.label_path
    set_path(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # prepare tokenizer
    tokenizer = prepare_tokenizer(config, model)

    # prepare labels
    labels = config['labels']

    # inference
    # NOTE(review): f_out is not closed if an exception escapes the loop;
    # consider a `with` block.
    f_out = open(opt.test_path + '.inference', 'w', encoding='utf-8')
    total_examples = 0
    total_duration_time = 0.0
    with torch.no_grad(), open(opt.test_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            start_time = time.time()
            sent, label = line.strip().split('\t')
            x_raw = sent.split()
            y_raw = label
            text = ' '.join(x_raw)
            x = encode_text(config, tokenizer, text)
            x = to_device(x, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    # roberta-style models take no token_type_ids input
                    if config['emb_class'] in [
                            'roberta', 'distilbert', 'bart'
                    ]:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)

            predicted = logits.argmax(1)
            predicted = to_numpy(predicted)[0]
            predicted_raw = labels[predicted]
            f_out.write(text + '\t' + y_raw + '\t' + predicted_raw + '\n')
            total_examples += 1
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Inference] : up to the {} examples".format(
                    total_examples))
                break
            # the first iteration is excluded from the running total below
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
    f_out.close()
    # NOTE(review): divides by (total_examples - 1); this raises
    # ZeroDivisionError when at most one example was processed.
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time, total_duration_time / (total_examples - 1)))
Exemplo n.º 20
0
def evaluate(opt):
    """
    Evaluates a sentence-classification model on the test dataset.

    Depending on flags, instead converts the checkpoint to ONNX or TVM and
    returns early; otherwise runs inference via onnxruntime or the PyTorch
    model, prints a classification report and logs timing statistics.

    :param opt: namespace with paths, device and runtime flags
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] :{}".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(
                opt.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # convert to tvm format
    if opt.convert_tvm:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_tvm(config, model, x)
        logger.info("[TVM model saved] : {}".format(opt.tvm_path))
        return

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    correct = 0
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    # roberta-style models take no token_type_ids input
                    if config['emb_class'] in [
                            'roberta', 'distilbert', 'bart'
                    ]:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)

            # accumulate predictions/targets across batches as numpy arrays
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(
            classification_report(ys,
                                  preds_ids,
                                  target_names=label_names,
                                  digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        # NOTE(review): Logger.warn is deprecated; prefer Logger.warning.
        logger.warn(str(e))

    acc = correct / total_examples
    whole_time = float((time.time() - whole_st_time) * 1000)
    avg_time = (whole_time - first_time) / (total_examples - first_examples)
    # write predictions to file
    write_prediction(opt, preds, labels)
    logger.info("[Accuracy] : {:.4f}, {:5d}/{:5d}".format(
        acc, correct, total_examples))
    logger.info("[Elapsed Time] : {}ms, {}ms on average".format(
        whole_time, avg_time))
    # NOTE(review): divides by (total_examples - 1); raises
    # ZeroDivisionError when at most one example was processed.
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time, total_duration_time / (total_examples - 1)))
Exemplo n.º 21
0
def evaluate(opt):
    """
    Evaluates a sequence-labeling model on the test dataset.

    Depending on flags, instead converts the checkpoint to ONNX and returns
    early; otherwise runs inference via onnxruntime or the PyTorch model
    (optionally with a CRF head), computes seqeval metrics and writes
    predictions to file.

    :param opt: namespace with paths, device and runtime flags
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(config)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx format
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved at {}".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved at {}".format(
                opt.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {
                        ort_session.get_inputs()[0].name: x[0],
                        ort_session.get_inputs()[1].name: x[1]
                    }
                    if opt.use_char_cnn:
                        ort_inputs[ort_session.get_inputs()[2].name] = x[2]
                # NOTE(review): plain `if`, not `elif` — ort_inputs stays
                # unbound when emb_class matches neither branch.
                if config['emb_class'] in [
                        'bert', 'distilbert', 'albert', 'roberta', 'bart',
                        'electra'
                ]:
                    # distilbert/bart take no token_type_ids input
                    if config['emb_class'] in ['distilbert', 'bart']:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                    if opt.bert_use_pos:
                        ort_inputs[ort_session.get_inputs()[3].name] = x[3]
                if opt.use_crf:
                    logits, prediction = ort_session.run(None, ort_inputs)
                    prediction = to_device(torch.tensor(prediction),
                                           opt.device)
                    logits = to_device(torch.tensor(logits), opt.device)
                else:
                    logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), opt.device)
            else:
                if opt.use_crf: logits, prediction = model(x)
                else: logits = model(x)

            # accumulate predictions/targets across batches as numpy arrays
            if preds is None:
                if opt.use_crf: preds = to_numpy(prediction)
                else: preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    whole_time = float((time.time() - whole_st_time) * 1000)
    avg_time = (whole_time - first_time) / (total_examples - first_examples)
    # without CRF, preds holds logits: reduce to label ids per token
    if not opt.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            # padding positions are excluded from the metrics
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    f1 = ret['f1']
    # write predicted labels to file
    default_label = config['default_label']
    write_prediction(opt, ys, preds, labels, pad_label_id, default_label)

    logger.info("[F1] : {}, {}".format(f1, total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(
        total_examples, whole_time, avg_time))
    # NOTE(review): divides by (total_examples - 1); raises
    # ZeroDivisionError when at most one example was processed.
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time, total_duration_time / (total_examples - 1)))
Exemplo n.º 22
0
 def get_grads(self):
     """
     Returns the current gradients of all parameters as one flat array.

     :return: 1-D numpy array with every parameter gradient concatenated
     """
     # np.hstack allocates a fresh array (and .flatten() copies each
     # piece), so the result shares no memory with the model parameters;
     # the previous deepcopy of it was redundant.
     return np.hstack(
         [to_numpy(v.grad).flatten() for v in self.parameters()])
Exemplo n.º 23
0
 def get_params(self):
     """
     Returns parameters of the actor as one flat array.

     :return: 1-D numpy array with every parameter tensor concatenated
     """
     # np.hstack allocates a fresh array (and .flatten() copies each
     # piece), so the result shares no memory with the model parameters;
     # the previous deepcopy of it was redundant.
     return np.hstack(
         [to_numpy(v).flatten() for v in self.parameters()])
Exemplo n.º 24
0
def evaluate(args):
    """
    Evaluates a token-classification model on the test set and reports
    precision / recall / F1 (via seqeval). When multi-task learning is
    enabled (--bert_use_mtl) it also reports a sequence-classification
    report and confusion matrix for the auxiliary head.

    Side effects:
      - may convert the model to ONNX (and optionally quantize it) and return early
      - writes token-level predictions (and glabel predictions for MTL) to file
      - logs F1 scores and timing statistics

    :param args: parsed command-line arguments (model path, device, flags, ...)
    """
    # set config
    config = load_config(args)
    if args.num_threads > 0: torch.set_num_threads(args.num_threads)
    config['args'] = args
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(args.model_path, device=args.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx format
    if args.convert_onnx:
        # FIXME not working for --use_crf
        batch = next(iter(test_loader))
        # non glove/elmo embeddings carry an extra glabel tensor per batch
        if config['emb_class'] not in ['glove', 'elmo']:
            x, y, gy = batch
        else:
            x, y = batch
        x = to_device(x, args.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] : {}".format(args.onnx_path))
        # quantize onnx
        if args.quantize_onnx:
            quantize_onnx(args.onnx_path, args.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(
                args.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if args.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = args.num_threads
        sess_options.intra_op_num_threads = args.num_threads
        ort_session = ort.InferenceSession(args.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if args.enable_dqm:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    gpreds = None
    gys = None
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            if config['emb_class'] not in ['glove', 'elmo']:
                x, y, gy = batch
                gy = to_device(gy, args.device)
            else:
                x, y = batch

            x = to_device(x, args.device)
            y = to_device(y, args.device)
            if args.enable_ort:
                ort_inputs = build_onnx_input(config, ort_session, x)
                if args.use_crf:
                    # FIXME not working for --use_crf
                    if args.bert_use_mtl:
                        logits, prediction, glogits = ort_session.run(
                            None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits, prediction = ort_session.run(None, ort_inputs)
                    prediction = to_device(torch.tensor(prediction),
                                           args.device)
                    logits = to_device(torch.tensor(logits), args.device)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = ort_session.run(None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), args.device)
                    logits = torch.softmax(logits, dim=-1)
            else:
                if args.use_crf:
                    if args.bert_use_mtl:
                        logits, prediction, glogits = model(x)
                    else:
                        logits, prediction = model(x)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = model(x)
                    else:
                        logits = model(x)
                    logits = torch.softmax(logits, dim=-1)

            # accumulate token-level predictions and gold labels across batches
            if preds is None:
                if args.use_crf: preds = to_numpy(prediction)
                else: preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if args.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)

            # accumulate sequence-level predictions/golds for the MTL head
            if args.bert_use_mtl:
                glogits = torch.softmax(glogits, dim=-1)
                if gpreds is None:
                    gpreds = to_numpy(glogits)
                    gys = to_numpy(gy)
                else:
                    gpreds = np.append(gpreds, to_numpy(glogits), axis=0)
                    gys = np.append(gys, to_numpy(gy), axis=0)

            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if args.num_examples != 0 and total_examples >= args.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    whole_time = float((time.time() - whole_st_time) * 1000)
    # guard against division by zero when only the first batch was processed
    avg_time = (whole_time - first_time) / max(
        total_examples - first_examples, 1)

    # generate report for token classification
    if not args.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = config['labels']
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:  # skip padding positions
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    # write predicted labels to file
    write_prediction(config, model, ys, preds, labels)

    # generate report for sequence classification
    if args.bert_use_mtl:
        glabels = config['glabels']
        glabel_names = [
            v for k, v in sorted(glabels.items(), key=lambda x: x[0])
        ]
        glabel_ids = [k for k in glabels.keys()]
        gpreds_ids = np.argmax(gpreds, axis=1)
        try:
            g_report = sequence_classification_report(
                gys,
                gpreds_ids,
                target_names=glabel_names,
                labels=glabel_ids,
                digits=4)
            g_report_dict = sequence_classification_report(
                gys,
                gpreds_ids,
                target_names=glabel_names,
                labels=glabel_ids,
                output_dict=True)
            g_matrix = confusion_matrix(gys, gpreds_ids)
            g_f1 = g_report_dict['micro avg']['f1-score']
            # assign all keys together so `'g_report' in ret` implies all are set
            ret['g_report'] = g_report
            ret['g_report_dict'] = g_report_dict
            ret['g_f1'] = g_f1
            ret['g_matrix'] = g_matrix
        except Exception as e:
            # Logger.warn() is deprecated; use warning()
            logger.warning(str(e))
        # Only print/log the report when its computation succeeded above.
        # Previously these ran unconditionally and raised KeyError whenever
        # sequence_classification_report failed.
        if 'g_report' in ret:
            print(ret['g_report'])
            print(ret['g_f1'])
            print(ret['g_matrix'])
            logger.info("[sequence classification F1] : {}, {}".format(
                ret['g_f1'], total_examples))
        # write predicted glabels to file
        write_gprediction(args, gpreds, glabels)

    logger.info("[token classification F1] : {}, {}".format(
        ret['f1'], total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(
        total_examples, whole_time, avg_time))
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time,
            # guard against division by zero when only one example was seen
            total_duration_time / max(total_examples - 1, 1)))
Exemplo n.º 25
0
def test_kron():
    """Test kron, vec and vecr identities.

    Exercises u.Kron (a factored Kronecker product) against dense torch
    tensors: trace, Frobenius norm, inverse / pseudo-inverse expansion,
    Kron @ Kron products, Vec/Kron matmul in both orders, and the classic
    vec/vecr Kronecker identities for vec(A X B).
    """
    # float64 keeps the exact rational expectations below within tolerance
    torch.set_default_dtype(torch.float64)
    a = torch.tensor([1, 2, 3, 4]).reshape(2, 2)
    b = torch.tensor([5, 6, 7, 8]).reshape(2, 2)
    # trace(kron(a, b)) == trace(a) * trace(b) == 5 * 13 == 65
    u.check_close(u.Kron(a, b).trace(), 65)

    a = torch.tensor([[2., 7, 9], [1, 9, 8], [2, 7, 5]])
    b = torch.tensor([[6., 6, 1], [10, 7, 7], [7, 10, 10]])
    Ck = u.Kron(a, b)
    # ||kron(a, b)||_F == ||a||_F * ||b||_F
    u.check_close(a.flatten().norm() * b.flatten().norm(), Ck.frobenius_norm())

    u.check_close(Ck.frobenius_norm(), 4 * math.sqrt(11635.))

    # Ci: pseudo-inverse of kron(a, b), precomputed analytically (exact rationals)
    Ci = [[
        0, 5 / 102, -(7 / 204), 0, -(70 / 561), 49 / 561, 0, 125 / 1122,
        -(175 / 2244)
    ],
          [
              1 / 20, -(53 / 1020), 8 / 255, -(7 / 55), 371 / 2805,
              -(224 / 2805), 5 / 44, -(265 / 2244), 40 / 561
          ],
          [
              -(1 / 20), 3 / 170, 3 / 170, 7 / 55, -(42 / 935), -(42 / 935),
              -(5 / 44), 15 / 374, 15 / 374
          ],
          [
              0, -(5 / 102), 7 / 204, 0, 20 / 561, -(14 / 561), 0, 35 / 1122,
              -(49 / 2244)
          ],
          [
              -(1 / 20), 53 / 1020, -(8 / 255), 2 / 55, -(106 / 2805),
              64 / 2805, 7 / 220, -(371 / 11220), 56 / 2805
          ],
          [
              1 / 20, -(3 / 170), -(3 / 170), -(2 / 55), 12 / 935, 12 / 935,
              -(7 / 220), 21 / 1870, 21 / 1870
          ], [0, 5 / 102, -(7 / 204), 0, 0, 0, 0, -(5 / 102), 7 / 204],
          [
              1 / 20, -(53 / 1020), 8 / 255, 0, 0, 0, -(1 / 20), 53 / 1020,
              -(8 / 255)
          ],
          [
              -(1 / 20), 3 / 170, 3 / 170, 0, 0, 0, 1 / 20, -(3 / 170),
              -(3 / 170)
          ]]
    C = Ck.expand()
    C0 = u.to_numpy(C)
    Ci = torch.tensor(Ci)
    # Moore-Penrose condition: C @ C+ @ C == C
    u.check_close(C @ Ci @ C, C)

    # factored inverse/pinv must agree with the dense torch equivalents
    u.check_close(Ck.inv().expand(), torch.inverse(Ck.expand()))
    u.check_close(Ck.inv().expand_vec(), torch.inverse(Ck.expand_vec()))
    u.check_close(Ck.pinv().expand(), torch.pinverse(Ck.expand()))

    u.check_close(linalg.pinv(C0), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(torch.pinverse(C), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(Ck.inv().expand(), Ci, rtol=1e-5, atol=1e-6)
    u.check_close(Ck.pinv().expand(), Ci, rtol=1e-5, atol=1e-6)

    # Kron @ Kron expands to the product of the dense expansions
    Ck2 = u.Kron(b, 2 * a)
    u.check_close((Ck @ Ck2).expand(), Ck.expand() @ Ck2.expand())
    u.check_close((Ck @ Ck2).expand_vec(), Ck.expand_vec() @ Ck2.expand_vec())

    # NOTE(review): G, g, H, Gt, gt below are constructed but never asserted on
    d2 = 3
    d1 = 2
    G = torch.randn(d2, d1)
    g = u.vec(G)
    H = u.Kron(u.random_cov(d1), u.random_cov(d2))

    Gt = G.t()
    gt = g.reshape(1, -1)

    # Vec @ Kron and Kron @ Vec with hand-computed expected values
    vecX = u.Vec([1, 2, 3, 4], shape=(2, 2))
    K = u.Kron([[5, 6], [7, 8]], [[9, 10], [11, 12]])

    u.check_equal(vecX @ K, [644, 706, 748, 820])
    u.check_equal(K @ vecX, [543, 655, 737, 889])

    u.check_equal(u.matmul(vecX @ K, vecX), 7538)
    u.check_equal(vecX @ (vecX @ K), 7538)
    u.check_equal(vecX @ vecX, 30)

    vecX = u.Vec([1, 2], shape=(1, 2))
    K = u.Kron([[5]], [[9, 10], [11, 12]])

    # ||[1, 2]||^2 == 1 + 4 == 5
    u.check_equal(vecX.norm()**2, 5)

    # check kronecker rules
    X = torch.tensor([[1., 2], [3, 4]])
    A = torch.tensor([[5., 6], [7, 8]])
    B = torch.tensor([[9., 10], [11, 12]])
    x = u.Vec(X)

    # kron/vec/vecr identities
    u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()))
    u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x)
    u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X))
    u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B))

    def extra_checks(A, X, B):
        # same identities as above, plus their normal_form() variants
        x = u.Vec(X)
        u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()))
        u.check_equal(u.Vec(A @ X @ B), u.Kron(B.t(), A) @ x)
        u.check_equal(u.Vecr(A @ X @ B), u.Kron(A, B.t()) @ u.Vecr(X))
        u.check_equal(u.Vecr(A @ X @ B), u.Vecr(X) @ u.Kron(A.t(), B))
        u.check_equal(u.Vecr(A @ X @ B),
                      u.Vecr(X) @ u.Kron(A.t(), B).normal_form())
        u.check_equal(u.Vecr(A @ X @ B),
                      u.matmul(u.Kron(A, B.t()).normal_form(), u.Vecr(X)))
        u.check_equal(u.Vec(A @ X @ B),
                      u.matmul(u.Kron(B.t(), A).normal_form(), x))
        u.check_equal(u.Vec(A @ X @ B), x @ u.Kron(B, A.t()).normal_form())
        u.check_equal(u.Vec(A @ X @ B),
                      x.normal_form() @ u.Kron(B, A.t()).normal_form())
        u.check_equal(u.Vec(A @ X @ B),
                      u.Kron(B.t(), A).normal_form() @ x.normal_form())
        u.check_equal(u.Vecr(A @ X @ B),
                      u.Kron(A, B.t()).normal_form() @ u.Vecr(X).normal_form())
        u.check_equal(u.Vecr(A @ X @ B),
                      u.Vecr(X).normal_form() @ u.Kron(A.t(), B).normal_form())

    # shape checks
    d1, d2 = 3, 4
    extra_checks(torch.ones((d1, d1)), torch.ones((d1, d2)),
                 torch.ones((d2, d2)))

    # NOTE(review): A, B and x below are unused — the function ends here
    A = torch.rand(d1, d1)
    B = torch.rand(d2, d2)
    #x = torch.rand((d1*d2))
    #X = x.t().reshape(d1, d2)
    # X = torch.rand((d1, d2))
    # x = u.vec(X)
    x = torch.rand((d1 * d2))
Exemplo n.º 26
0
def test_lyapunov():
    """Test that scipy lyapunov solver works correctly.

    Builds a tiny linear model (d=2 inputs, n=3 data points), extracts
    Hessian-like and covariance-like matrices from one optimizer step, then
    cross-checks four ways of solving the continuous Lyapunov equation
    H P + P H = sigma: least-squares, scipy.linalg.solve_lyapunov, an SVD
    variant, and a spectral variant.
    """
    d = 2
    n = 3
    torch.set_default_dtype(torch.float32)

    model = Net(d)

    # fix the weights so all downstream quantities are deterministic
    w0 = torch.tensor([[1, 2]]).float()
    assert w0.shape[1] == d
    model.w.weight.data.copy_(w0)

    X = torch.tensor([[-2, 0, 2], [-1, 1, 3]]).float()
    assert X.shape[0] == d
    assert X.shape[1] == n

    Y = torch.tensor([[0, 1, 2]]).float()
    assert Y.shape[1] == X.shape[1]

    data = X.t()  # PyTorch expects batch dimension first
    target = Y.t()
    assert data.shape[0] == n

    output = model(data)
    # residuals, aka e
    residuals = output - Y.t()

    def compute_loss(residuals_):
        # mean squared error with the conventional 1/2 factor
        return torch.sum(residuals_ * residuals_) / (2 * n)

    loss = compute_loss(residuals)

    # hand-computed loss for the fixed w0/X/Y above
    assert loss - 8.83333 < 1e-5, torch.norm(loss) - 8.83333

    # use learning rate 0 to avoid changing parameter vector
    optim_kwargs = dict(
        lr=0,
        momentum=0,
        weight_decay=0,
        l2_reg=0,
        bias_correction=False,
        acc_steps=1,
        curv_type="Cov",
        curv_shapes={"Linear": "Kron"},
        momentum_type="preconditioned",
    )
    curv_args = dict(damping=1, ema_decay=1)  # todo: damping
    optimizer = SecondOrderOptimizer(model,
                                     **optim_kwargs,
                                     curv_kwargs=curv_args)

    def backward(last_layer: str) -> Callable:
        """Creates closure that backpropagates either from output layer or from loss layer"""
        def closure() -> Tuple[Optional[torch.Tensor], torch.Tensor]:
            optimizer.zero_grad()
            output = model(data)
            if last_layer == "output":
                # backprop a vector of ones straight from the model output
                output.backward(torch.ones_like(target))
                return None, output
            elif last_layer == 'loss':
                loss = compute_loss(output - target)
                loss.backward()
                return loss, output
            else:
                assert False, 'last layer must be "output" or "loss"'

        return closure

    #    loss = compute_loss(output - Y.t())
    #    loss.backward()

    # one step with lr=0: populates curvature/gradient stats without moving weights
    loss, output = optimizer.step(closure=backward('loss'))
    J = X.t()
    # NOTE(review): data_input/grad_output appear to be per-layer stats captured
    # by the optimizer's hooks — confirm against SecondOrderOptimizer
    A = model.w.data_input
    B = model.w.grad_output * n
    G = residuals.repeat(1, d) * J
    losses = torch.stack([compute_loss(r) for r in residuals])
    g = G.sum(dim=0) / n
    # empirical Fisher and its centered version (gradient covariance)
    efisher = G.t() @ G / n
    sigma = efisher - u.outer(g, g)
    loss2 = (residuals * residuals).sum() / (2 * n)
    H = J.t() @ J / n
    noise_variance = torch.trace(H.inverse() @ sigma)

    # H is not quite symmetric, make it so
    H = H + H.t()

    # Slow way
    p_sigma = u.lyapunov_lstsq(H, sigma)

    sigma0 = u.to_numpy(sigma)
    H0 = u.to_numpy(H)

    # Alternative faster way
    p_sigma2 = scipy.linalg.solve_lyapunov(H0, sigma0)
    # residual of H P + P H - sigma should be ~0
    print(f"Error 1: {np.max(abs(H0 @ p_sigma2 + p_sigma2 @ H0 - sigma0))}")
    u.check_close(p_sigma, p_sigma2)

    # alternative through SVD
    p_sigma3 = lyapunov_svd(torch.tensor(H0), torch.tensor(sigma0))
    u.check_close(p_sigma2, p_sigma3)

    # alternative through evals
    p_sigma4 = u.lyapunov_spectral(torch.tensor(H0), torch.tensor(sigma0))
    u.check_close(p_sigma2, p_sigma4)
 def train_one_step(self, data_i, total_steps_so_far):
     """Run one classifier training step on a batch and return its losses as numpy."""
     batch_images, batch_labels = self.prepare_images(data_i)
     classifier_losses = self.train_classifier_one_step(batch_images,
                                                        batch_labels)
     self.adjust_lr_if_necessary(total_steps_so_far)
     return util.to_numpy(classifier_losses)
Exemplo n.º 28
0
def main():
    """End-to-end EDD2020 segmentation pipeline: resize the dataset, train a
    model, then evaluate on the test split, saving masks/bboxes/plots.

    Command-line args:
      --base_dir: dataset root containing EDD2020_release-I_2020-01-15/
      --model:    model name passed to get_model (default 'unetresnet')
    """
    np.random.seed(42)

    parser = ArgumentParser()
    parser.add_argument('--base_dir', type=str, default='./EDD2020/')  # './'
    parser.add_argument('--model', type=str, default='unetresnet')
    args = parser.parse_args()

    base_dir = args.base_dir
    model_name = args.model

    # resize masks/images only when the destination dirs were newly created
    success = create_dir(base_dir + 'resized_masks/')
    if success:
        resize_my_images(base_dir + 'EDD2020_release-I_2020-01-15/masks/',
                         base_dir + 'resized_masks/',
                         is_masks=True)

    success = create_dir(base_dir + 'resized_images/')
    success &= create_dir(base_dir + 'resized_bboxs/')
    if success:
        resize_my_images(
            base_dir + 'EDD2020_release-I_2020-01-15/originalImages/',
            base_dir + 'resized_images/',
            is_masks=False,
            bboxs_src=base_dir + 'EDD2020_release-I_2020-01-15/bbox/',
            bboxs_dst=base_dir + 'resized_bboxs/')

    loader = get_edd_loader(base_dir, shuffle_dataset=True)

    # model + optimizer: Adam with a step decay every 10 epochs
    model = get_model(model_name, in_channel=3, n_classes=5)
    optimizer_func = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = lr_scheduler.StepLR(optimizer_func, step_size=10, gamma=0.1)

    trainer = Trainer(model, optimizer=optimizer_func, scheduler=scheduler)
    trainer.train_model(loader, num_epochs=30)

    # output directories for test-time artifacts
    create_dir(base_dir + 'test/')
    create_dir(base_dir + 'test/images')
    create_dir(base_dir + 'test/masks')
    create_dir(base_dir + 'test/masks_pred')
    create_dir(base_dir + 'test/bboxs')
    create_dir(base_dir + 'test/bboxs_pred')
    create_dir(base_dir + 'test/plots')

    metrics = defaultdict(float)
    plot = Plot(base_dir + 'test/plots/')

    index = 0
    # NOTE(review): 'epoch' here is really the batch index of the test loader
    for epoch, (images, bboxs, masks) in enumerate(loader['test']):
        masks_preds = trainer.predict(images)

        # print('B ', bboxs)
        bboxs = bbox_tensor_to_bbox(bboxs.squeeze(0))
        # print('A ', bboxs)

        bboxs_preds = compute_bboxs_from_masks(masks_preds.squeeze(0))

        plot.plot_image_truemask_predictedmask(images, masks, masks_preds,
                                               index)
        plot.plot_image_truebbox_predictedbbox(images, bboxs, bboxs_preds,
                                               index)

        # accumulates per-batch losses into `metrics` in place
        calc_loss(torch.Tensor(masks_preds), masks, metrics)

        image, mask, mask_pred = images[0], masks[0], masks_preds[0]

        save_to_tif(base_dir + 'test/masks/MASK_{:03d}.tif'.format(index),
                    to_numpy(mask))

        # NOTE(review): mask_pred is saved without to_numpy — presumably
        # trainer.predict already returns numpy; confirm
        save_to_tif(
            base_dir + 'test/masks_pred/MASK_PRED_{:03d}.tif'.format(index),
            mask_pred)

        # CHW -> HWC for matplotlib
        image = image.detach().cpu().numpy().swapaxes(0, 2).swapaxes(0, 1)
        plt.imsave(base_dir + 'test/images/IMG_{:03d}.png'.format(index),
                   image)

        index += 1

    # NOTE(review): 'epoch' is the last loop value; this raises NameError if
    # loader['test'] is empty
    computed_metrics = compute_metrics(metrics, epoch + 1)
    print_metrics(computed_metrics, 'test')
Exemplo n.º 29
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks')
    parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers")

    parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--hess_samples', type=int, default=1,
                        help='number of samples when sub-sampling outputs, 0 for exact hessian')
    parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections")

    parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--stats_num_batches', type=int, default=1)
    parser.add_argument('--run_name', type=str, default='noname')
    parser.add_argument('--launch_blocking', type=int, default=0)
    parser.add_argument('--sampled', type=int, default=0)
    parser.add_argument('--curv', type=str, default='kfac',
                        help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full')
    parser.add_argument('--log_spectra', type=int, default=0)

    u.seed_random(1)
    gl.args = parser.parse_args()
    args = gl.args
    u.seed_random(1)

    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    d1 = 28 * 28
    d = [784, 2500, 2000, 1500, 1000, 500, 10]

    # number of samples per datapoint. Used to normalize kfac
    model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout)
    model = model.to(gl.device)
    autograd_lib.register(model)

    assert args.dataset_size >= args.stats_batch_size
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=True,
                                                   drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                   drop_last=False)
    train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                    drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0

    for step in range(args.stats_steps):
        epoch = gl.token_count // 60000
        lr = optimizer.param_groups[0]['lr']
        print('token_count', gl.token_count)
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
            print(f'time: {time.perf_counter() - last_outer:.2f}')
        last_outer = time.perf_counter()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})')

        # save log
        metrics = {'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss,
                   'train_loss': train_loss, 'train_accuracy': train_accuracy,
                   'lr': optimizer.param_groups[0]['lr'],
                   'momentum': optimizer.param_groups[0].get('momentum', 0)}
        u.log_scalars(metrics)

        def mom_update(buffer, val):
            buffer *= 0.9
            buffer += val * 0.1

        if not args.skip_stats:
            # number of samples passed through
            n = args.stats_batch_size * args.stats_num_batches

            # quanti
            forward_stats = defaultdict(lambda: AttrDefault(float))

            hessians = defaultdict(lambda: AttrDefault(float))
            jacobians = defaultdict(lambda: AttrDefault(float))
            fishers = defaultdict(lambda: AttrDefault(float))  # empirical fisher/gradient
            quad_fishers = defaultdict(lambda: AttrDefault(float))  # gradient statistics that depend on fisher (4th order moments)
            train_regrets = defaultdict(list)
            test_regrets1 = defaultdict(list)
            test_regrets2 = defaultdict(list)
            train_regrets_opt = defaultdict(list)
            test_regrets_opt = defaultdict(list)
            cosines = defaultdict(list)
            dot_products = defaultdict(list)
            hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))

            current = None
            current_histograms = None

            for i in range(args.stats_num_batches):
                activations = {}
                backprops = {}

                def save_activations(layer, A, _):
                    activations[layer] = A
                    forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A)

                print('forward')
                with u.timeit("stats_forward"):
                    with autograd_lib.module_hook(save_activations):
                        data, targets = next(stats_iter)
                        output = model(data)
                        loss = loss_fn(output, targets) * len(output)

                def compute_stats(layer, _, B):
                    A = activations[layer]
                    if current == fishers:
                        backprops[layer] = B

                    # about 27ms per layer
                    with u.timeit('compute_stats'):
                        current[layer].BB += torch.einsum("ni,nj->ij", B, B)  # TODO(y): index consistency
                        current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)
                        current[layer].BA += torch.einsum("ni,nj->ij", B, A)
                        current[layer].a += torch.einsum("ni->i", A)
                        current[layer].b += torch.einsum("nk->k", B)
                        current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum()

                        # compute curvatures in direction of all gradiennts
                        if current is fishers:
                            assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments"
                            hess = hessians[layer]
                            jac = jacobians[layer]
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n
                            # Per-example squared gradient norms via Kronecker structure:
                            # ||g_i||^2 = ||a_i||^2 * ||b_i||^2 for g_i = b_i a_i^T.
                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))

                            current[layer].min_norm2 = min(norms)
                            current[layer].median_norm2 = torch.median(norms)
                            current[layer].max_norm2 = max(norms)

                            # Quadratic forms g_i^T H g_i and g_i^T J g_i in factored form.
                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))

                            current[layer].norm += norms.sum()
                            current_histograms[layer].norms.extend(torch.sqrt(norms))
                            # Curvature along each per-example gradient, normalized by its
                            # squared norm; skip_nans drops zero-norm examples (0/0 -> NaN).
                            current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum()
                            current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms))
                            # NOTE(review): '+=' here sums per-batch maxima/medians across
                            # stats batches rather than taking a global max/median -- confirm
                            # this is intended (e.g. divided by batch count downstream).
                            current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max()
                            current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median()

                            current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms))
                            current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum()
                            current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max()
                            current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median()

                            # Fraction of non-positive activations/backprops (sparsity proxy).
                            current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel()
                            current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel()

                            current[layer].mean_activation += torch.mean(A)
                            current[layer].mean_activation2 += torch.mean(A*A)
                            # NOTE(review): '=' overwrites across stats batches while the two
                            # activation lines above accumulate with '+=' -- with more than one
                            # stats batch only the last batch survives; confirm intended.
                            current[layer].mean_backprop = torch.mean(B)
                            current[layer].mean_backprop2 = torch.mean(B*B)

                            current[layer].norms_hess += torch.sqrt(norms2_hess).sum()
                            current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess))
                            # NOTE(review): norms_hess sums sqrt'd values but norms_jac sums the
                            # squared form -- confirm the asymmetry is intended.
                            current[layer].norms_jac += norms2_jac.sum()
                            current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac))

                            # Second moments at per-example scale: reuse Hessian backward
                            # factor, swap in the forward activation covariance for AA.
                            normalized_moments = copy.copy(hessians[layer])
                            normalized_moments.AA = forward_stats[layer].AA
                            normalized_moments = u.divide_attributes(normalized_moments, n)

                            # Expected regret of an SGD step: offset=0 evaluates on the same
                            # example (train), offset>0 on a shifted example (test proxy);
                            # alpha=None uses the per-example optimal step size.
                            train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2,
                                                                           m=normalized_moments, approx=args.curv)
                            train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0,
                                                                            m=normalized_moments, approx=args.curv)
                            cosines_ = autograd_lib.offset_cosines(A, B)
                            train_regrets[layer].extend(train_regrets_)
                            test_regrets1[layer].extend(test_regrets1_)
                            test_regrets2[layer].extend(test_regrets2_)
                            train_regrets_opt[layer].extend(train_regrets_opt_)
                            test_regrets_opt[layer].extend(test_regrets_opt_)
                            cosines[layer].extend(cosines_)
                            dot_products[layer].extend(autograd_lib.offset_dotprod(A, B))

                        # statistics of the form g.Sigma.g
                        elif current == quad_fishers:
                            hess = hessians[layer]
                            sigma = fishers[layer]
                            jac = jacobians[layer]
                            Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n

                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))
                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))
                            norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1))

                            current[layer].norm += norms.sum()  # TODO(y) remove, redundant with norm2 above
                            current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum()
                            # NOTE(review): '=' overwrites across batches while curv_hess_max
                            # below accumulates with '+=' -- confirm which is intended.
                            current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max()
                            current[layer].curv_sigma_median = skip_nans(norms_sigma / norms).median()
                            current[layer].curv_hess += skip_nans(norms2_hess / norms).sum()
                            current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max()
                            # Lyapunov-style noise/curvature ratios g.Sigma.g / g.H.g.
                            # NOTE(review): *_mean accumulates with '+=' but *_max overwrites
                            # with '=' -- confirm the mixed accumulation is intended.
                            current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean()
                            current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess))
                            current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean()
                            current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac))

                print('backward')
                with u.timeit("backprop_H"):
                    with autograd_lib.module_hook(compute_stats):
                        # Four backward passes over the same graph. compute_stats decides
                        # which statistics bundle to update by comparing against the
                        # enclosing-scope 'current' / 'current_histograms' bindings, so we
                        # rebind them before each pass; retain_graph keeps the graph alive
                        # for the subsequent passes.
                        current = hessians
                        current_histograms = hessians_histograms
                        autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled,
                                                      retain_graph=True)  # 600 ms
                        current = jacobians
                        current_histograms = jacobians_histograms
                        autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True)  # 600 ms
                        current = fishers
                        current_histograms = fishers_histograms
                        model.zero_grad()
                        loss.backward(retain_graph=True)  # 60 ms
                        current = quad_fishers
                        current_histograms = quad_fishers_histograms
                        model.zero_grad()
                        loss.backward()  # 60 ms

            print('summarize')
            for (i, layer) in enumerate(model.layers):
                stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers}

                # evaluate stats from
                # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb
                for stats_name in stats_dict:
                    s = AttrDict()
                    stats = stats_dict[stats_name][layer]

                    # Copy forward-pass statistics into this backward-stats bundle so the
                    # per-layer summaries below can read everything from 'stats'.
                    for key in forward_stats[layer]:
                        # print(f'copying {key} in {stats_name}, {layer}')
                        try:
                            assert stats[key] == float()
                        except:
                            # NOTE(review): bare except swallows the assertion, and the
                            # f-string below is a no-op expression -- the "overwrite"
                            # warning is never raised or logged. Consider logging it.
                            f"Trying to overwrite {key} in {stats_name}, {layer}"
                        stats[key] = forward_stats[layer][key]

                    # Diagonal of the curvature, normalized per example.
                    diag: torch.Tensor = stats.diag / n

                    # jacobian:
                    # curv in direction of gradient goes down to roughly 0.3-1
                    # maximum curvature goes up to 1000-2000
                    #
                    # Hessian:
                    # max curv goes down to 1, in direction of gradient 0.0001

                    s.diag_l2 = torch.max(diag)  # 40 - 3000 smaller than kfac l2 for jac
                    s.diag_fro = torch.norm(
                        diag)  # jacobian grows to 0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also
                    s.diag_trace = diag.sum()  # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition
                    s.diag_average = diag.mean()

                    # normalize for mean loss
                    BB = stats.BB / n
                    AA = stats.AA / n
                    # A_evals, _ = torch.symeig(AA)   # averaging 120ms per hit, 90 hits
                    # B_evals, _ = torch.symeig(BB)

                    # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals)    # 60x larger than diag_l2. layer0/hess has down/up phase transition. layer5/jacobian has up/down phase transition
                    s.kfac_trace = torch.trace(AA) * torch.trace(BB)  # 0/hess down/up tr, 5/jac sharp phase transition
                    # NOTE(review): kfac_fro uses the unnormalized stats.AA/stats.BB while
                    # kfac_trace uses the /n-normalized AA/BB -- confirm the scale mismatch.
                    s.kfac_fro = torch.norm(stats.AA) * torch.norm(
                        stats.BB)  # 0/hess has down/up tr, 5/jac up/down transition
                    # s.kfac_erank = s.kfac_trace / s.kfac_l2   # first layer has 25, rest 15, all layers go down except last, last noisy
                    # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape)

                    s.diversity = (stats.norm2 / n) / u.norm_squared(
                        stats.BA / n)  # gradient diversity. Goes up 3x. Bottom layer has most diversity. Jacobian diversity much less noisy than everythingelse

                    # discrepancy of KFAC based on exact values of diagonal approximation
                    # average difference normalized by average diagonal magnitude
                    diag_kfac = torch.einsum('ll,ii->li', BB, AA)
                    s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs())
                    u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s))

                # openai batch size stat
                s = AttrDict()
                hess = hessians[layer]
                jac = jacobians[layer]
                fish = fishers[layer]
                quad_fish = quad_fishers[layer]

                # the following check passes, but is expensive
                # if args.stats_num_batches == 1:
                #    u.check_close(fisher[layer].BA, layer.weight.grad)

                def trsum(A, B):
                    """Trace inner product tr(A B^T), computed elementwise."""
                    return torch.sum(A * B)

                # Minibatch mean gradient of this layer (KFAC 'BA' factor over n examples).
                grad = fishers[layer].BA / n
                s.grad_fro = torch.norm(grad)

                # get norms
                # NOTE(review): these read quad_fish.lyap_hess_sum / lyap_jac_sum, but the
                # visible stats pass writes lyap_hess_mean / lyap_jac_mean -- confirm the
                # *_sum attributes are populated somewhere.
                s.lyap_hess_max = quad_fish.lyap_hess_max
                s.lyap_hess_ave = quad_fish.lyap_hess_sum / n
                s.lyap_jac_max = quad_fish.lyap_jac_max
                s.lyap_jac_ave = quad_fish.lyap_jac_sum / n
                s.hess_trace = hess.diag.sum() / n
                s.jac_trace = jac.diag.sum() / n

                # Version 1 of Jain stochastic rates, use Hessian for curvature
                b = args.train_batch_size

                # Rayleigh quotient of the gradient under factored curvature:
                # tr((BB g AA) g^T) / tr(g g^T).
                s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad)
                s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad)

                # compute gradient noise statistics
                # fish.BB has /n factor twice, hence don't need extra /n on fish.AA
                # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected
                s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n))
                s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n))
                s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n)
                s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n)
                s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n)

                s.mean_norm = torch.sqrt(fish.norm2) / n
                s.min_norm = torch.sqrt(fish.min_norm2)
                s.median_norm = torch.sqrt(fish.median_norm2)
                s.max_norm = torch.sqrt(fish.max_norm2)
                s.enorms = u.norm_squared(grad)
                s.a_sparsity = fish.a_sparsity
                s.b_sparsity = fish.b_sparsity
                s.mean_activation = fish.mean_activation
                s.msr_activation = torch.sqrt(fish.mean_activation2)
                s.mean_backprop = fish.mean_backprop
                s.msr_backprop = torch.sqrt(fish.mean_backprop2)

                # Variance-like quantity: E||g_i||^2 - ||E g_i||^2.
                s.norms_centered = fish.norm2 / n - u.norm_squared(grad)
                s.norms_hess = fish.norms_hess / n
                s.norms_jac = fish.norms_jac / n

                s.hess_curv_grad = fish.curv_hess / n  # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.hess_curv_grad_max = fish.curv_hess_max   # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.hess_curv_grad_median = fish.curv_hess_median   # phase transition, hits minimum loss in layer 1, then starts going up. Other layers take longer to reach minimum. Decreases with depth.
                s.sigma_curv_grad = quad_fish.curv_sigma / n
                s.sigma_curv_grad_max = quad_fish.curv_sigma_max
                s.sigma_curv_grad_median = quad_fish.curv_sigma_median
                s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad
                # NOTE(review): curv_ratio is not set in the visible stats pass -- confirm
                # it is populated elsewhere.
                s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n
                s.band_yaida = 0.25 * lr * s.mean_norm**2
                s.band_yaida_centered = 0.25 * lr * s.norms_centered

                s.jac_curv_grad = fish.curv_jac / n  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.
                s.jac_curv_grad_max = fish.curv_jac_max  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.
                s.jac_curv_grad_median = fish.curv_jac_median  # this one has much lower variance than jac_curv. Reaches peak at 10k steps, also kfac error reaches peak there. Decreases with depth except for last layer.

                # OpenAI gradient noise statistics
                s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n)
                # NOTE(review): uses the uncentered jac_noise while the hess version above
                # uses the centered variant -- confirm the asymmetry is intended.
                s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n)

                train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_, cosines_, dot_products_ = (torch.stack(r[layer]) for r in (train_regrets, test_regrets1, test_regrets2, train_regrets_opt, test_regrets_opt, cosines, dot_products))
                s.train_regret = train_regrets_.median()  # use median because outliers make it hard to see the trend
                s.test_regret1 = test_regrets1_.median()
                s.test_regret2 = test_regrets2_.median()
                s.test_regret_opt = test_regrets_opt_.median()
                s.train_regret_opt = train_regrets_opt_.median()
                s.mean_dot_product = torch.mean(dot_products_)
                s.median_dot_product = torch.median(dot_products_)
                # NOTE(review): dead code -- 'a' is never used; safe to delete.
                a = [1, 2, 3]

                s.median_cosine = cosines_.median()
                s.mean_cosine = cosines_.mean()

                # get learning rates
                L1 = s.hess_curv_grad / n
                L2 = s.jac_curv_grad / n
                diversity = (fish.norm2 / n) / u.norm_squared(grad)
                robust_diversity = (fish.norm2 / n) / fish.median_norm2
                dotprod_diversity = fish.median_norm2 / s.median_dot_product
                s.lr1 = 2 / (L1 * diversity)
                s.lr2 = 2 / (L2 * diversity)
                s.lr3 = 2 / (L2 * robust_diversity)
                s.lr4 = 2 / (L2 * dotprod_diversity)

                # Positive eigenvalue spectra of the Kronecker factors.
                hess_A = u.symeig_pos_evals(hess.AA / n)
                hess_B = u.symeig_pos_evals(hess.BB / n)
                fish_A = u.symeig_pos_evals(fish.AA / n)
                fish_B = u.symeig_pos_evals(fish.BB / n)
                jac_A = u.symeig_pos_evals(jac.AA / n)
                jac_B = u.symeig_pos_evals(jac.BB / n)
                u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)})
                u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)})
                u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)})
                u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)})
                u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)})
                u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)})
                gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                # NOTE(review): likely copy-paste bugs -- the 'fish' and 'jac' histograms
                # below also use (hess_A, hess_B); presumably they should use
                # (fish_A, fish_B) and (jac_A, jac_B) respectively.
                gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())

                # Top eigenvalue of the Kronecker product is the product of top factor
                # eigenvalues.
                s.hess_l2 = max(hess_A) * max(hess_B)
                s.jac_l2 = max(jac_A) * max(jac_B)
                s.fish_l2 = max(fish_A) * max(fish_B)
                # NOTE(review): duplicate of the s.hess_trace assignment above.
                s.hess_trace = hess.diag.sum() / n

                s.jain1_sto = 1/(s.hess_trace + 2 * s.hess_l2)
                s.jain1_det = 1/s.hess_l2

                # Harmonic combination of stochastic/deterministic rates, weighted by
                # batch size b.
                s.jain1_lr = (1 / b) * (1/s.jain1_sto) + (b - 1) / b * (1/s.jain1_det)
                s.jain1_lr = 2 / s.jain1_lr

                s.regret_ratio = (
                            train_regrets_opt_ / test_regrets_opt_).median()  # ratio between train and test regret, large means overfitting
                u.log_scalars(u.nest_stats(f'layer-{i}', s))

                # compute stats that would let you bound rho
                if i == 0:  # only compute this once, for output layer
                    hhh = hessians[model.layers[-1]].BB / n
                    fff = fishers[model.layers[-1]].BB / n
                    d = fff.shape[0]
                    # Solve the Lyapunov equation H L + L H = 2 Sigma (spectral method);
                    # Lcheap is the cheaper pseudo-inverse approximation Sigma H^+.
                    L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8)
                    L_evals = u.symeig_pos_evals(L)
                    Lcheap = fff @ u.pinv(hhh, cond=1e-8)
                    Lcheap_evals = u.eig_real(Lcheap)

                    u.log_scalars({f'mismatch/rho': d/erank(L_evals)})
                    u.log_scalars({f'mismatch/rho_cheap': d/erank(Lcheap_evals)})
                    u.log_scalars({f'mismatch/diagonalizability': erank(L_evals)/erank(Lcheap_evals)})  # 1 means diagonalizable
                    u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False)
                    u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False)
                    u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True)
                    u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True)

                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step())

                if args.log_spectra:
                    with u.timeit('spectrum'):
                        # 2/alpha
                        # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det
                        # s.jain1_lr = 1 / s.jain1_lr

                        # hess.diag_trace, jac.diag_trace

                        # Version 2 of Jain stochastic rates, use Jacobian squared for curvature
                        s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave
                        s.jain2_det = s.jac_l2
                        s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det
                        s.jain2_lr = 1 / s.jain2_lr

                        u.log_spectrum(f'layer-{i}/hess_A', hess_A)
                        u.log_spectrum(f'layer-{i}/hess_B', hess_B)
                        u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten())
                        u.log_spectrum(f'layer-{i}/jac_A', jac_A)
                        u.log_spectrum(f'layer-{i}/jac_B', jac_B)
                        u.log_spectrum(f'layer-{i}/fish_A', fish_A)
                        u.log_spectrum(f'layer-{i}/fish_B', fish_B)

                        u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum()/hess_B.sum()})

                        # NOTE(review): torch.eig is deprecated (removed in recent PyTorch);
                        # migrate to torch.linalg.eig, which returns complex eigenvalues
                        # instead of the (real, imag) column layout relied on below.
                        L = torch.eig(u.lyapunov_spectral(hess.BB, 2*fish.BB, cond=1e-8))[0]
                        L = L[:, 0]  # extract real part
                        L = L.sort()[0]
                        L = torch.flip(L, [0])

                        L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0]
                        L_cheap = L_cheap[:, 0]  # extract real part
                        L_cheap = L_cheap.sort()[0]
                        L_cheap = torch.flip(L_cheap, [0])

                        d = len(hess_B)
                        u.log_spectrum(f'layer-{i}/Lyap', L)
                        u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap)

                        u.log_scalars({f'layer-{i}/dims': d})
                        u.log_scalars({f'layer-{i}/L_erank': erank(L)})
                        u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)})

                        u.log_scalars({f'layer-{i}/rho': d/erank(L)})
                        u.log_scalars({f'layer-{i}/rho_cheap': d/erank(L_cheap)})

        model.train()
        with u.timeit('train'):
            # Standard SGD steps between stats evaluations.
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                # NOTE(review): redundant after optimizer.zero_grad() when the optimizer
                # holds all model parameters -- harmless, could be dropped.
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                # Decoupled (AdamW-style) weight decay applied manually after the step:
                # multiplicative shrink of every parameter.
                if args.weight_decay:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1 - args.weight_decay)

                gl.token_count += data.shape[0]

    gl.event_writer.close()