Exemple #1
0
def evaluate(model, config, val_loader):
    """Evaluate a sequence-labeling model on `val_loader`.

    Loss is the CRF negative log-likelihood when `opt.use_crf`, otherwise
    token-level cross entropy with `pad_label_id` ignored. Metrics
    (precision/recall/F1/report) are computed with seqeval over the
    pad-stripped label sequences.

    Returns a dict with keys: 'loss', 'precision', 'recall', 'f1', 'report'.
    """
    model.eval()
    opt = config['opt']
    pad_label_id = config['pad_label_id']

    eval_loss = 0.
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(val_loader)
    prog = Progbar(target=n_batches)
    # Accumulate per-batch arrays in lists and concatenate once at the end:
    # np.append copies the whole accumulated array every batch (quadratic).
    preds_list = []
    ys_list = []
    with torch.no_grad():
        for i, (x, y) in enumerate(val_loader):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            if opt.use_crf:
                logits, prediction = model(x)
                # mask: 1 for real tokens, 0 where the token id is 0 (padding)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(
                    opt.device)
                log_likelihood = model.crf(logits,
                                           y,
                                           mask=mask,
                                           reduction='mean')
                loss = -1 * log_likelihood
                preds_list.append(to_numpy(prediction))
            else:
                logits = model(x)
                loss = criterion(logits.view(-1, model.label_size), y.view(-1))
                preds_list.append(to_numpy(logits))
            ys_list.append(to_numpy(y))
            eval_loss += loss.item()
            prog.update(i + 1, [('eval curr loss', loss.item())])
    preds = np.concatenate(preds_list, axis=0)
    ys = np.concatenate(ys_list, axis=0)
    eval_loss = eval_loss / n_batches
    # without a CRF, preds are logits (batch, seq, label) -> argmax to ids
    if not opt.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "loss": eval_loss,
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    return ret
Exemple #2
0
    def forward_core(self, input):
        """Run the LSTM over `input` and project it with the dense head.

        Hidden and cell states start from zeros on every call.
        # assumes input is (seq_len, batch, input_size) — TODO confirm
        """
        batch_size = input.size(1)
        h0 = to_device(torch.zeros(batch_size, self.hidden_size))
        c0 = to_device(torch.zeros(batch_size, self.hidden_size))
        # take only the output sequence from the LSTM, then apply the
        # dense projection for prediction
        lstm_outputs = self.lstm(input, h0, c0)[0]
        return self.forward_dense(lstm_outputs)
Exemple #3
0
def evaluate(model, config, valid_loader, eval_device=None):
    """Evaluate a classification model on `valid_loader`.

    When `eval_device` is truthy, each batch is moved to `args.device`
    before the forward pass. Prints a classification report and a
    confusion matrix (best-effort), then returns (average_loss, accuracy).
    """
    args = config['args']

    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss()
    # Collect per-batch outputs and concatenate once: np.append copies the
    # whole accumulated array on every batch (quadratic).
    preds_list = []
    ys_list = []
    # eval mode does not change per batch — set it once, outside the loop
    model.eval()
    with torch.no_grad():
        iterator = tqdm(valid_loader,
                        total=len(valid_loader),
                        desc="Evaluate")
        for i, (x, y) in enumerate(iterator):
            if eval_device:
                x = to_device(x, args.device)
                y = to_device(y, args.device)
            logits = model(x)
            loss = criterion(logits, y)
            # softmax after computing cross entropy loss
            logits = torch.softmax(logits, dim=-1)
            logits = logits.cpu().numpy()
            y = y.cpu().numpy()
            preds_list.append(logits)
            ys_list.append(y)
            predicted = np.argmax(logits, axis=1)
            correct += np.sum(np.equal(predicted, y).astype(int))
            cur_examples = y.size
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    preds = np.concatenate(preds_list, axis=0)
    ys = np.concatenate(ys_list, axis=0)
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(
            classification_report(ys,
                                  preds_ids,
                                  target_names=label_names,
                                  digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        # logger.warn is deprecated; report failures without aborting eval
        logger.warning(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
Exemple #4
0
    def forward_core(self, input):
        """One truncated forward pass with a synthetic-gradient bridge.

        Runs the e-prop LSTM over `input`, applies the dense head, and lets
        the synthetic-gradient network predict the gradient at the segment
        boundary so the next segment can start from a detached state.
        # assumes input is (seq_len, batch, input_size) — TODO confirm

        Returns (dense outputs, the initial cell state used, synth_grad).
        """
        batch_size = input.size(1)
        initial_h = to_device(torch.zeros(batch_size, self.hidden_size))

        # initialize eligibility vectors only on first run
        # (explicit `is None` instead of the `type(...) == type(None)` idiom)
        if self.eligibility_vectors is None:
            self.initial_c = to_device(
                torch.zeros(batch_size, self.hidden_size)).requires_grad_()
            self.eligibility_vectors = [
                to_device(
                    torch.zeros(batch_size,
                                3 * self.hidden_size,
                                self.input_size,
                                requires_grad=False)),
                to_device(
                    torch.zeros(batch_size,
                                3 * self.hidden_size,
                                self.hidden_size,
                                requires_grad=False)),
                to_device(
                    torch.zeros(batch_size,
                                3 * self.hidden_size,
                                1,
                                requires_grad=False))
            ]

        initial_c = self.initial_c
        # lstm and dense pass for prediction
        lstm_out, final_c, self.eligibility_vectors = self.lstm(
            input, initial_h.detach(), initial_c, self.eligibility_vectors)
        lstm_out = self.forward_dense(lstm_out)

        # take the last output of the network to let the synth grad network predict the synthetic gradient
        synth_grad = self.synthetic_gradient_net(lstm_out[:, -1, :].detach())

        # set the gradient of the last internal state equal to the synthetic gradient
        self.initial_c, lstm_out, _ = SyntheticGradient.apply(
            final_c, lstm_out, synth_grad.detach())

        # detach initial state from the compute graph of [t_{m-1}+1, ..., t_{m}] but
        # build a new compute graph ...
        self.initial_c = self.initial_c.detach().requires_grad_()
        # ... and make sure to also detach eligibility vectors
        self.eligibility_vectors = [
            self.eligibility_vectors[0].detach(),
            self.eligibility_vectors[1].detach(),
            self.eligibility_vectors[2].detach()
        ]

        # return both the output as well as the synthetic gradient
        return lstm_out, initial_c, synth_grad
Exemple #5
0
def generate_single_lable_memory_data(num_observations, sequence_length, time_delta=None):
    '''
    Generates num_observations sequences of length sequence_length where the entire sequence is
    filled with 0s, except for one singular signal at a random position inside the sequence.
    The label of a sequence is the signal value minus one, because the NLLLoss used to train
    the network expects labels in the range 0 .. num_classes - 1 instead of 1 .. num_classes.
    If time_delta is given, the signal is placed at least time_delta steps before the end.
    Returns (data, labels) as float/long torch tensors moved via to_device.
    '''
    size = (num_observations, sequence_length, 1)
    data = np.zeros(size)
    labels = np.zeros((num_observations, 1))
    for i, row in enumerate(data):
        # one signal value drawn from [1, MEM_NUM_CLASSES)
        signal = np.random.randint(1, MEM_NUM_CLASSES, 1)
        # explicit None test (instead of truthiness) so the intent of the
        # default is clear; time_delta=0 keeps the full range either way
        last_possible_signal = sequence_length if time_delta is None else sequence_length - time_delta
        column = np.random.randint(0, last_possible_signal, 1)
        row[column] = signal
        labels[i] = signal - 1

    return to_device(torch.from_numpy(data).float()), to_device(torch.from_numpy(labels).long())
Exemple #6
0
def evaluate(model, config, val_loader):
    """Evaluate a classification model on `val_loader`.

    Computes example-weighted average cross-entropy loss and accuracy,
    prints a classification report and confusion matrix (best-effort),
    and returns (average_loss, accuracy).
    """
    opt = config['opt']
    model.eval()
    total_loss = 0.
    total_examples = 0
    correct = 0
    criterion = torch.nn.CrossEntropyLoss().to(opt.device)
    # Accumulate per-batch arrays and concatenate once at the end:
    # np.append copies the whole accumulated array every batch (quadratic).
    preds_list = []
    ys_list = []
    with torch.no_grad():
        for i, (x, y) in tqdm(enumerate(val_loader), total=len(val_loader)):
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)
            logits = model(x)
            loss = criterion(logits, y)

            preds_list.append(to_numpy(logits))
            ys_list.append(to_numpy(y))
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_loss += (loss.item() * cur_examples)
            total_examples += cur_examples
    preds = np.concatenate(preds_list, axis=0)
    ys = np.concatenate(ys_list, axis=0)
    # generate report
    labels = model.labels
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(
            classification_report(ys,
                                  preds_ids,
                                  target_names=label_names,
                                  digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        # logger.warn is deprecated; report failures without aborting eval
        logger.warning(str(e))
    cur_loss = total_loss / total_examples
    cur_acc = correct / total_examples
    return cur_loss, cur_acc
Exemple #7
0
    def forward(self, input, initial_h, initial_c, eligibility_vectors=None):
        """One e-prop LSTM pass over a full sequence.

        input: (seq_len x batch_size x input_size)
        initial_h / initial_c: (batch x hidden_size)
        eligibility_vectors: optional list of the three eligibility tensors
            carried over from a previous segment; fresh zero tensors are
            created when it is None or empty.

        Returns (stacked hidden states, final cell state,
        [ev_w_ih, ev_w_hh, ev_b]).
        """
        inputs = input.unbind(0)
        input_size = input.size(2)
        hidden_size = initial_h.size(1)
        batch_size = input.size(1)
        hx = initial_h
        cx = initial_c

        # The default used to be a mutable `[]` argument — `None` avoids the
        # shared-mutable-default pitfall while accepting the same call sites.
        if not eligibility_vectors:
            ev_w_ih_x = to_device(
                torch.zeros(batch_size,
                            3 * hidden_size,
                            input_size,
                            requires_grad=False))
            ev_w_hh_x = to_device(
                torch.zeros(batch_size,
                            3 * hidden_size,
                            hidden_size,
                            requires_grad=False))
            ev_b_x = to_device(
                torch.zeros(batch_size,
                            3 * hidden_size,
                            1,
                            requires_grad=False))
        else:
            ev_w_ih_x = eligibility_vectors[0]
            ev_w_hh_x = eligibility_vectors[1]
            ev_b_x = eligibility_vectors[2]

        forgetgate = to_device(torch.zeros(batch_size, hidden_size, 1))

        outputs = []
        for step_input in inputs:
            hx, cx, ev_w_ih_x, ev_w_hh_x, ev_b_x, new_forgetgate = self.cell(
                step_input, hx, cx, ev_w_ih_x, ev_w_hh_x, ev_b_x, forgetgate)
            _, forgetgate = ForgetGate.apply(forgetgate, new_forgetgate)
            outputs.append(hx)

        return torch.stack(outputs), cx, [ev_w_ih_x, ev_w_hh_x, ev_b_x]
def validate(data_loader, model, best_score, global_step, cfg):
    """
    Run the model over the validation set, log accuracy / F1 / precision /
    recall, and save a checkpoint whenever validation accuracy improves.

    (The previous docstring described state-dict loading and belonged to a
    different function.)

    :param data_loader: validation loader yielding (images, labels) batches
    :param model: wrapped classifier (probabilities are taken via
        ``model.module.probability``)
    :param best_score: best validation accuracy seen so far
    :param global_step: current training step, stored with the checkpoint
    :param cfg: config object; ``cfg.gpu`` selects the device
    :return: the (possibly updated) best validation accuracy
    """
    model.eval()
    gts, predictions = [], []

    log.info("Validation started...")
    for data in data_loader:
        imgs, labels = data
        imgs = util.to_device(imgs, gpu=cfg.gpu)

        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            preds = torch.argmax(probs, dim=1).cpu().numpy()

        labels = labels.cpu().detach().numpy()

        predictions.extend(preds)
        gts.extend(labels)

    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)

    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))

    # checkpoint on new best accuracy (a sibling variant keys on F1 instead)
    if acc > best_score:
        save_config = {
            'name': config.name,
            'save_dir': config.ckpts_dir,
            'global_step': global_step,
            'clf_report': report
        }
        save_model(model=model, config=save_config)
        best_score = acc
    log.info("Validation end")

    model.train()
    return best_score
Exemple #9
0
 def preprocess(self, data):
     """Extract the request payload, encode it, and move it to the device.

     Looks for the text under the 'data' key first, then 'body'; a
     non-empty payload is decoded as UTF-8. Returns (encoded_input, text).
     """
     opt = self.config['opt']
     logger.info("data: %s", data)
     request = data[0]
     text = request.get('data')
     if text is None:
         text = request.get('body')
     if text:
         text = text.decode('utf-8')
     logger.info("[Received text] %s", text)
     encoded = self.encode_text(text)
     encoded = to_device(encoded, opt.device)
     return encoded, text
Exemple #10
0
def generate_store_and_recall_data(num_observations, sequence_length, recall_repetition=0, time_delta=None):
    '''
    Generates num_observations sequences of length sequence_length. The input
    consists of three different signals:
    1. Data signal: A stream of data points with a value between 1 and SR_NUM_CLASSES + 1
    2. Store signal: Either 0 if NO storage is required, or 1 if the current data point is supposed to
                     be stored by the network
    3. Recall signal: Either 0 if NO recall is required, or 1 if the last stored data point is supposed
                      to be recalled by the network.
    The label is a sequence of length num_observations that is 0 if the recall signal is 0 or has the value
    of the last stored data point if the corresponding recall signal is 1.
    time_delta is the minimum number of steps between store and recall (defaults to 1);
    recall_repetition repeats the recall signal on that many extra consecutive steps.
    '''
    size = (num_observations, sequence_length, 1)
    data_stream = np.random.randint(1, SR_NUM_CLASSES + 1, size)
    store_signal = np.zeros(size)
    recall_signal = np.zeros(size)
    labels = np.zeros(size)
    data = []
    time_delta = time_delta if time_delta else 1

    for i, row in enumerate(data_stream):
        # select time step where store signal is sent; an int argument to
        # np.random.choice samples uniformly from range(n) without building
        # a throwaway Python list
        store = np.random.choice(sequence_length - time_delta - recall_repetition)

        # select time step where recall signal is sent, at least time_delta
        # steps after the store
        recall = np.random.choice(np.arange(store + time_delta, sequence_length - recall_repetition))

        for j in range(0, recall_repetition + 1):
            recall_signal[i][recall + j] = 1
            labels[i][recall + j] = row[store]

        # input features per time step: [data, store flag, recall flag]
        data.append(np.hstack((row, store_signal[i], recall_signal[i])))

    data = np.array(data)
    return to_device(torch.from_numpy(data).float()), to_device(torch.from_numpy(labels).long())
Exemple #11
0
def chose_task(memory_task, training_algorithm):
    """Select data generator, model, loss and train function for a run.

    Raises ValueError for an unknown memory_task or training_algorithm
    (previously this fell through and failed later with an unbound-name
    error).
    """
    # Chose the task and corresponding model:
    if memory_task == config.MEMORY:
        generate_data = generate_single_lable_memory_data
        input_size = config.MEM_INPUT_SIZE
        hidden_size = config.MEM_HIDEN_SIZE
        output_size = config.MEM_NUM_CLASSES
        single_output = True
        loss_function = nn.CrossEntropyLoss()
    elif memory_task == config.STORE_RECALL:
        generate_data = generate_store_and_recall_data
        input_size = config.SR_INPUT_SIZE
        hidden_size = config.SR_HIDEN_SIZE
        output_size = config.SR_NUM_CLASSES + 1
        single_output = False
        loss_function = nn.CrossEntropyLoss()
    else:
        raise ValueError(f"unknown memory task: {memory_task}")

    if training_algorithm == BPTT:
        model_constructor = BPTT_LSTM
        # the lambda wrappers added nothing — both BPTT and e-prop 1 train
        # with the plain BPTT step, so pass the function directly
        train_function = train_bptt
    elif training_algorithm == EPROP_1:
        model_constructor = EPROP1_LSTM
        train_function = train_bptt
    elif training_algorithm == EPROP_3:
        model_constructor = lambda in_size, h_size, o_size, single_output: EPROP3_LSTM(
            in_size,
            h_size,
            o_size,
            single_output=single_output)
        # e-prop 3 additionally needs the truncation window
        train_function = lambda model, optimizer, loss_func, batch_x, batch_y: train_eprop3(
            model,
            optimizer,
            loss_func,
            batch_x,
            batch_y,
            config.TRUNCATION_DELTA)
    else:
        raise ValueError(f"unknown training algorithm: {training_algorithm}")

    model = to_device(model_constructor(
            input_size,
            hidden_size,
            output_size,
            single_output=single_output))
    return generate_data, model, loss_function, train_function
Exemple #12
0
def validate(data_loader, model, best_score, global_step, cfg):
    """Score the model on the validation set; checkpoint on a new best F1.

    Returns the (possibly updated) best validation F1 score.
    """
    model.eval()
    gts, predictions = [], []

    log.info("Validation started...")
    for imgs, labels in data_loader:
        imgs = util.to_device(imgs, gpu=cfg.gpu)

        with torch.no_grad():
            logits = model(imgs)
            probs = model.module.probability(logits)
            batch_preds = torch.argmax(probs, dim=1).cpu().numpy()

        predictions.extend(batch_preds)
        gts.extend(labels.cpu().detach().numpy())

    predictions = np.array(predictions, dtype=np.int32)
    gts = np.array(gts, dtype=np.int32)
    acc, f1, prec, rec = util.clf_metrics(predictions=predictions,
                                          targets=gts,
                                          average="macro")
    report = classification_report(gts, predictions, output_dict=True)

    log.info("VALIDATION | Accuracy {:.4f} | F1 {:.4f} | Precision {:.4f} | "
             "Recall {:.4f}".format(acc, f1, prec, rec))

    # checkpoint whenever validation F1 improves
    if f1 > best_score:
        save_model(model=model,
                   config={
                       'name': config.name,
                       'save_dir': config.ckpts_dir,
                       'global_step': global_step,
                       'clf_report': report
                   })
        best_score = f1
    log.info("Validation end")

    model.train()
    return best_score
Exemple #13
0
# ---------------------------------------

# NOTE(review): flat training-loop fragment; the final clip_grad_norm_
# call is truncated in this excerpt, so the chunk is left byte-identical.
summ_counter = 0
mean_losses = np.zeros(3)
mean_metrics = np.zeros(5)
# resume from the current `epoch` and run up to epoch_total
for epoch in tqdm(range(epoch, epoch_total),
                  initial=epoch,
                  total=epoch_total,
                  leave=False,
                  dynamic_ncols=True):
    for idx, batch in enumerate(
            tqdm(BackgroundGenerator(dataloader),
                 total=len(dataloader),
                 leave=False,
                 dynamic_ncols=True)):
        # move the whole batch tuple to the training device in one call
        text_data, text_pos, text_len, text_mask, mel_data, mel_pos, mel_len, mel_mask, gate, text_data_ = to_device(
            batch, device)

        text_out, gate_out, att_heads_enc, att_heads_dec, att_heads = model(
            mel_data, mel_pos, mel_mask, text_data_, text_pos, text_mask)

        # text loss is class-weighted and ignores padding (index 0)
        loss_text = F.cross_entropy(text_out.view(-1, nr_symbols),
                                    text_data.view(-1),
                                    weight=weights,
                                    ignore_index=0)
        loss_gate = F.binary_cross_entropy(gate_out, gate)
        loss = loss_text + loss_gate
        optimizer.zero_grad()
        loss.backward()
        # print(mel_mask.grad)

        grad_norm_enc = nn.utils.clip_grad_norm_(model.encoder.parameters(),
Exemple #14
0
def train_epoch(model, config, train_loader, val_loader, epoch_i,
                best_eval_f1):
    """Train the sequence-labeling model for one epoch.

    Supports AMP (GradScaler/autocast), gradient accumulation, optional CRF
    loss, optional profiling, periodic mid-epoch evaluation, TensorBoard
    logging, and best-F1 checkpointing (including the finetuned BERT
    tokenizer/model when applicable).

    Returns (local_best_eval_loss, local_best_eval_f1, best_eval_f1).
    """
    opt = config['opt']

    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    pad_label_id = config['pad_label_id']

    # cross entropy ignores padded positions; used only on the non-CRF path
    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(train_loader)

    # train one epoch
    train_loss = 0.
    avg_loss = 0.
    local_best_eval_loss = float('inf')
    local_best_eval_f1 = 0
    st_time = time.time()
    optimizer.zero_grad()
    epoch_iterator = tqdm(train_loader,
                          total=len(train_loader),
                          desc=f"Epoch {epoch_i}")
    for local_step, (x, y) in enumerate(epoch_iterator):
        model.train()
        # global step counts batches across all epochs so far
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        if opt.use_crf:
            # CRF path: loss is the negative mean log-likelihood
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True,
                                          record_shapes=True) as prof:
                        logits, prediction = model(x)
                    print(prof.key_averages().table(
                        sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits, prediction = model(x)
                # mask: 1 for real tokens, 0 where the token id is 0 (padding)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(
                    opt.device)
                log_likelihood = model.crf(logits,
                                           y,
                                           mask=mask,
                                           reduction='mean')
                loss = -1 * log_likelihood
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        else:
            # plain softmax/cross-entropy path
            with autocast(enabled=opt.use_amp):
                if opt.use_profiler:
                    with profiler.profile(profile_memory=True,
                                          record_shapes=True) as prof:
                        logits = model(x)
                    print(prof.key_averages().table(
                        sort_by="self_cpu_memory_usage", row_limit=10))
                else:
                    logits = model(x)
                # reshape for computing loss
                logits_view = logits.view(-1, model.label_size)
                y_view = y.view(-1)
                loss = criterion(logits_view, y_view)
                if opt.gradient_accumulation_steps > 1:
                    loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        # step the optimizer only every gradient_accumulation_steps batches
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            curr_lr = scheduler.get_last_lr(
            )[0] if scheduler else optimizer.param_groups[0]['lr']
            epoch_iterator.set_description(
                f"Epoch {epoch_i}, local_step: {local_step}, loss: {loss:.3f}, curr_lr: {curr_lr:.7f}"
            )
            # periodic mid-epoch evaluation and best-model checkpointing
            if opt.eval_and_save_steps > 0 and global_step != 0 and global_step % opt.eval_and_save_steps == 0:
                # evaluate
                eval_ret = evaluate(model, config, val_loader)
                eval_loss = eval_ret['loss']
                eval_f1 = eval_ret['f1']
                if local_best_eval_loss > eval_loss:
                    local_best_eval_loss = eval_loss
                if local_best_eval_f1 < eval_f1: local_best_eval_f1 = eval_f1
                if writer:
                    writer.add_scalar('Loss/valid', eval_loss, global_step)
                    writer.add_scalar('F1/valid', eval_f1, global_step)
                    writer.add_scalar('LearningRate/train', curr_lr,
                                      global_step)
                if eval_f1 > best_eval_f1:
                    best_eval_f1 = eval_f1
                    if opt.save_path and not opt.hp_search_optuna:
                        logger.info("[Best model saved] : {}, {}".format(
                            eval_loss, eval_f1))
                        save_model(config, model)
                        # save finetuned bert model/config/tokenizer
                        if config['emb_class'] not in ['glove', 'elmo']:
                            if not os.path.exists(opt.bert_output_dir):
                                os.makedirs(opt.bert_output_dir)
                            model.bert_tokenizer.save_pretrained(
                                opt.bert_output_dir)
                            model.bert_model.save_pretrained(
                                opt.bert_output_dir)
        # back-propagation - end
        train_loss += loss.item()
        if writer: writer.add_scalar('Loss/train', loss.item(), global_step)
    avg_loss = train_loss / n_batches

    # evaluate at the end of epoch
    # NOTE(review): `global_step` and `curr_lr` below reuse the last loop
    # values; `curr_lr` is unbound if no accumulation step completed — verify
    eval_ret = evaluate(model, config, val_loader)
    eval_loss = eval_ret['loss']
    eval_f1 = eval_ret['f1']
    if local_best_eval_loss > eval_loss: local_best_eval_loss = eval_loss
    if local_best_eval_f1 < eval_f1: local_best_eval_f1 = eval_f1
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('F1/valid', eval_f1, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    if eval_f1 > best_eval_f1:
        best_eval_f1 = eval_f1
        if opt.save_path and not opt.hp_search_optuna:
            logger.info("[Best model saved] : {}, {}".format(
                eval_loss, eval_f1))
            save_model(config, model)
            # save finetuned bert model/config/tokenizer
            if config['emb_class'] not in ['glove', 'elmo']:
                if not os.path.exists(opt.bert_output_dir):
                    os.makedirs(opt.bert_output_dir)
                model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                model.bert_model.save_pretrained(opt.bert_output_dir)

    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60
    st_time = curr_time
    logs = {
        'epoch': epoch_i,
        'local_step': local_step + 1,
        'epoch_step': len(train_loader),
        'avg_loss': avg_loss,
        'local_best_eval_loss': local_best_eval_loss,
        'local_best_eval_f1': local_best_eval_f1,
        'best_eval_f1': best_eval_f1,
        'elapsed_time': elapsed_time
    }
    logger.info(json.dumps(logs, indent=4, ensure_ascii=False, sort_keys=True))

    return local_best_eval_loss, local_best_eval_f1, best_eval_f1
def prune_rewire(config, model, eval_loader, use_tqdm=True):
    """Estimate head/FFN importance on `eval_loader`, then rewire BERT.

    First pass: run evaluation with a differentiable head mask, accumulating
    gradient-based importance scores for attention heads and intermediate
    FFN neurons. Second pass: per layer, sort Q/K/V weights, the attention
    output matrix and the FFN matrices by importance and keep only the top
    `args.target_num_heads` heads / `args.target_ffn_dim` FFN dims, updating
    the model config in place.
    """

    args = config['opt']
    bert_model = model.bert_model

    # get the model ffn weights and biases
    inter_weights = torch.zeros(bert_model.config.num_hidden_layers, bert_model.config.intermediate_size, bert_model.config.hidden_size).to(args.device)
    inter_biases = torch.zeros(bert_model.config.num_hidden_layers, bert_model.config.intermediate_size).to(args.device)
    output_weights = torch.zeros(bert_model.config.num_hidden_layers, bert_model.config.hidden_size, bert_model.config.intermediate_size).to(args.device)

    layers = bert_model.base_model.encoder.layer
    # per-layer importance accumulators for attention heads and FFN neurons
    head_importance = torch.zeros(bert_model.config.num_hidden_layers, bert_model.config.num_attention_heads).to(args.device)
    ffn_importance = torch.zeros(bert_model.config.num_hidden_layers, bert_model.config.intermediate_size).to(args.device)

    # snapshot FFN weights/biases so importance can be weighted by magnitude
    for layer_num in range(bert_model.config.num_hidden_layers):
        inter_weights[layer_num] = layers._modules[str(layer_num)].intermediate.dense.weight.detach().to(args.device)
        inter_biases[layer_num] = layers._modules[str(layer_num)].intermediate.dense.bias.detach().to(args.device)
        output_weights[layer_num] = layers._modules[str(layer_num)].output.dense.weight.detach().to(args.device)

    # all-ones head mask with grad enabled: its gradient measures head importance
    head_mask = torch.ones(bert_model.config.num_hidden_layers, bert_model.config.num_attention_heads, requires_grad=True).to(args.device)

    # Eval!
    logger.info(f"***** Running evaluation for pruning *****")
    logger.info("  Num batches = %d", len(eval_loader))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    criterion = torch.nn.CrossEntropyLoss().to(args.device)

    eval_loader = tqdm(eval_loader, desc="Evaluating") if use_tqdm else eval_loader
    tot_tokens = 0.0
    for x, y in eval_loader:
        model.eval()
        x = to_device(x, args.device)
        y = to_device(y, args.device)

        logits, bert_outputs = model(x, return_bert_outputs=True, head_mask=head_mask)
        tmp_eval_loss = criterion(logits, y)

        eval_loss += tmp_eval_loss.mean().item()

        # for preventing head_mask.grad is None
        head_mask.retain_grad()

        # TODO accumulate? absolute value sum?
        tmp_eval_loss.backward()

        # collect attention confidence scores
        head_importance += head_mask.grad.abs().detach()

        # collect gradients of linear layers
        for layer_num in range(bert_model.config.num_hidden_layers):
            ffn_importance[layer_num] += torch.abs(
                torch.sum(layers._modules[str(layer_num)].intermediate.dense.weight.grad.detach()*inter_weights[layer_num], 1) 
                + layers._modules[str(layer_num)].intermediate.dense.bias.grad.detach()*inter_biases[layer_num])
 
        # x[1] appears to be the attention mask — TODO confirm input layout
        attention_mask = x[1]
        tot_tokens += attention_mask.float().detach().sum().data
        nb_eval_steps += 1

    # normalize by the total number of (unpadded) tokens seen
    head_importance /= tot_tokens

    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20

    # rewire the network
    head_importance = head_importance.cpu()
    ffn_importance = ffn_importance.cpu()
    num_heads = bert_model.config.num_attention_heads
    head_size = bert_model.config.hidden_size / num_heads
    for layer_num in range(bert_model.config.num_hidden_layers):
        # load query, key, value weights
        query_weight = layers._modules[str(layer_num)].attention.self.query.weight
        query_bias = layers._modules[str(layer_num)].attention.self.query.bias
        key_weight = layers._modules[str(layer_num)].attention.self.key.weight
        key_bias = layers._modules[str(layer_num)].attention.self.key.bias
        value_weight = layers._modules[str(layer_num)].attention.self.value.weight
        value_bias = layers._modules[str(layer_num)].attention.self.value.bias

        # sort query, key, value based on the confidence scores
        query_weight, query_bias = sort_by_importance(query_weight,
            query_bias,
            head_importance[layer_num],
            args.target_num_heads,
            head_size)
        print('query_weight = ', query_weight.shape)
        print('query_bias = ', query_bias.shape)
        layers._modules[str(layer_num)].attention.self.query.weight = torch.nn.Parameter(query_weight)
        layers._modules[str(layer_num)].attention.self.query.bias = torch.nn.Parameter(query_bias)
        key_weight, key_bias = sort_by_importance(key_weight,
            key_bias,
            head_importance[layer_num],
            args.target_num_heads,
            head_size)
        print('key_weight = ', key_weight.shape)
        print('key_bias = ', key_bias.shape)
        layers._modules[str(layer_num)].attention.self.key.weight = torch.nn.Parameter(key_weight)
        layers._modules[str(layer_num)].attention.self.key.bias = torch.nn.Parameter(key_bias)
        value_weight, value_bias = sort_by_importance(value_weight,
            value_bias,
            head_importance[layer_num],
            args.target_num_heads,
            head_size)
        print('value_weight = ', value_weight.shape)
        print('value_bias = ', value_bias.shape)
        layers._modules[str(layer_num)].attention.self.value.weight = torch.nn.Parameter(value_weight)
        layers._modules[str(layer_num)].attention.self.value.bias = torch.nn.Parameter(value_bias)

        # output matrix
        # the attention output projection is sorted along its *input* side,
        # hence the transpose before and after
        weight_sorted, _ = sort_by_importance(
            layers._modules[str(layer_num)].attention.output.dense.weight.transpose(0, 1),
            None,
            head_importance[layer_num],
            args.target_num_heads,
            head_size)
        weight_sorted = weight_sorted.transpose(0, 1)
        print('attention.output.dense.weight = ', weight_sorted.shape)
        layers._modules[str(layer_num)].attention.output.dense.weight = torch.nn.Parameter(weight_sorted)

        weight_sorted, bias_sorted = sort_by_importance(
            layers._modules[str(layer_num)].intermediate.dense.weight,
            layers._modules[str(layer_num)].intermediate.dense.bias, 
            ffn_importance[layer_num],
            args.target_ffn_dim,
            1)
        layers._modules[str(layer_num)].intermediate.dense.weight = torch.nn.Parameter(weight_sorted)
        layers._modules[str(layer_num)].intermediate.dense.bias = torch.nn.Parameter(bias_sorted)

        # ffn output matrix input side
        weight_sorted, _ = sort_by_importance(
            layers._modules[str(layer_num)].output.dense.weight.transpose(0, 1),
            None, 
            ffn_importance[layer_num],
            args.target_ffn_dim,
            1)
        weight_sorted = weight_sorted.transpose(0, 1)
        print('output.dense.weight = ', weight_sorted.shape)
        layers._modules[str(layer_num)].output.dense.weight = torch.nn.Parameter(weight_sorted)

    # set bert model's config for pruned model
    bert_model.config.num_attention_heads = min([num_heads, args.target_num_heads])
    bert_model.config.intermediate_size = layers._modules['0'].intermediate.dense.weight.size(0)
Exemple #16
0
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    """Train `model` for one epoch with AMP (GradScaler), then evaluate.

    Args:
        model: model under training; invoked as `model(x)`.
        config: dict providing 'optimizer', 'scheduler', 'writer' (tensorboard,
            may be falsy), 'scaler' (torch.cuda.amp.GradScaler) and 'opt'
            (options: device, criterion, use_amp, use_profiler,
            gradient_accumulation_steps, max_grad_norm,
            use_transformers_optimizer).
        train_loader: iterable of (x, y) training batches.
        val_loader: validation loader, forwarded to `evaluate()`.
        epoch_i: zero-based epoch index, used to derive the global step.

    Returns:
        (eval_loss, eval_acc) as returned by `evaluate()` on `val_loader`.
    """
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    opt = config['opt']

    # 'sum' reduction makes the MSE loss scale with batch size.
    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)

    # train one epoch
    model.train()
    total_loss = 0.
    total_examples = 0
    st_time = time.time()
    optimizer.zero_grad()
    for local_step, (x, y) in tqdm(enumerate(train_loader), total=len(train_loader)):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        with autocast(enabled=opt.use_amp):
            if opt.use_profiler:
                with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            loss = criterion(output, y)
            # scale down so the accumulated gradient matches the full-batch gradient
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            # unscale before clipping so max_grad_norm applies to true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if opt.use_transformers_optimizer: scheduler.step()
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        # example-weighted running loss
        total_loss += (loss.item() * cur_examples)
        if writer: writer.add_scalar('Loss/train', loss.item(), global_step)
    cur_loss = total_loss / total_examples

    # evaluate
    eval_loss, eval_acc = evaluate(model, config, val_loader)
    elapsed_time = (time.time() - st_time) / 60
    curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:6.3f}, valid loss {:6.3f}, valid acc {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'.\
            format(epoch_i, local_step+1, len(train_loader), cur_loss, eval_loss, eval_acc, curr_lr, elapsed_time)) 
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('Acc/valid', eval_acc, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_acc
Exemple #17
0
def evaluate(args):
    """Evaluate a token-classification model (optionally multi-task) on the test set.

    Depending on `args`, this either:
      * converts the model to ONNX (optionally quantizing it) and returns, or
      * runs inference over the test loader — via onnxruntime when
        `args.enable_ort`, otherwise the pytorch model (optionally
        dynamically quantized when `args.enable_dqm`) — and prints seqeval
        token-level metrics plus, when `args.bert_use_mtl`, a sequence-level
        classification report.

    Side effects: writes token predictions via `write_prediction()` (and
    sequence predictions via `write_gprediction()` for the multi-task head)
    and logs timing statistics.
    """
    # set config
    config = load_config(args)
    if args.num_threads > 0: torch.set_num_threads(args.num_threads)
    config['args'] = args
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(args.model_path, device=args.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx format
    if args.convert_onnx:
        # FIXME not working for --use_crf
        batch = next(iter(test_loader))
        # non-glove/elmo batches carry gy: sequence-level (global) labels
        # used by the multi-task head below.
        if config['emb_class'] not in ['glove', 'elmo']:
            x, y, gy = batch
        else:
            x, y = batch
        x = to_device(x, args.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] : {}".format(args.onnx_path))
        # quantize onnx
        if args.quantize_onnx:
            quantize_onnx(args.onnx_path, args.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(
                args.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if args.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = args.num_threads
        sess_options.intra_op_num_threads = args.num_threads
        ort_session = ort.InferenceSession(args.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if args.enable_dqm:
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None    # token-level predictions (CRF) or logits (no CRF)
    ys = None       # token-level gold label ids
    gpreds = None   # sequence-level logits (multi-task head)
    gys = None      # sequence-level gold labels (multi-task head)
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, batch in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            if config['emb_class'] not in ['glove', 'elmo']:
                x, y, gy = batch
                gy = to_device(gy, args.device)
            else:
                x, y = batch

            x = to_device(x, args.device)
            y = to_device(y, args.device)
            if args.enable_ort:
                ort_inputs = build_onnx_input(config, ort_session, x)
                if args.use_crf:
                    # FIXME not working for --use_crf
                    if args.bert_use_mtl:
                        logits, prediction, glogits = ort_session.run(
                            None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits, prediction = ort_session.run(None, ort_inputs)
                    # ORT returns numpy arrays; convert back to tensors on device
                    prediction = to_device(torch.tensor(prediction),
                                           args.device)
                    logits = to_device(torch.tensor(logits), args.device)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = ort_session.run(None, ort_inputs)
                        glogits = to_device(torch.tensor(glogits), args.device)
                    else:
                        logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), args.device)
                    logits = torch.softmax(logits, dim=-1)
            else:
                if args.use_crf:
                    if args.bert_use_mtl:
                        logits, prediction, glogits = model(x)
                    else:
                        logits, prediction = model(x)
                else:
                    if args.bert_use_mtl:
                        logits, glogits = model(x)
                    else:
                        logits = model(x)
                    logits = torch.softmax(logits, dim=-1)

            # accumulate token-level outputs across batches
            if preds is None:
                if args.use_crf: preds = to_numpy(prediction)
                else: preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if args.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)

            # accumulate sequence-level outputs for the multi-task head
            if args.bert_use_mtl:
                glogits = torch.softmax(glogits, dim=-1)
                if gpreds is None:
                    gpreds = to_numpy(glogits)
                    gys = to_numpy(gy)
                else:
                    gpreds = np.append(gpreds, to_numpy(glogits), axis=0)
                    gys = np.append(gys, to_numpy(gy), axis=0)

            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if args.num_examples != 0 and total_examples >= args.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            '''
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
            '''
    whole_time = float((time.time() - whole_st_time) * 1000)
    # NOTE(review): ZeroDivisionError when the loader yields a single batch
    # (total_examples == first_examples) — confirm loaders always have >1 batch.
    avg_time = (whole_time - first_time) / (total_examples - first_examples)

    # generate report for token classification
    # without CRF, preds holds softmaxed logits of shape (batch, seq, label);
    # collapse the label axis into predicted ids
    if not args.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = config['labels']
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            # skip padding positions when building label sequences
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    # write predicted labels to file
    write_prediction(config, model, ys, preds, labels)

    # generate report for sequence classification
    if args.bert_use_mtl:
        glabels = config['glabels']
        glabel_names = [
            v for k, v in sorted(glabels.items(), key=lambda x: x[0])
        ]
        glabel_ids = [k for k in glabels.keys()]
        gpreds_ids = np.argmax(gpreds, axis=1)
        try:
            g_report = sequence_classification_report(
                gys,
                gpreds_ids,
                target_names=glabel_names,
                labels=glabel_ids,
                digits=4)
            g_report_dict = sequence_classification_report(
                gys,
                gpreds_ids,
                target_names=glabel_names,
                labels=glabel_ids,
                output_dict=True)
            g_matrix = confusion_matrix(gys, gpreds_ids)
            ret['g_report'] = g_report
            ret['g_report_dict'] = g_report_dict
            ret['g_f1'] = g_report_dict['micro avg']['f1-score']
            ret['g_matrix'] = g_matrix
        except Exception as e:
            # NOTE(review): logger.warn is deprecated (use logger.warning);
            # also, if this except path runs, the ret['g_*'] accesses below
            # raise KeyError — the failure is not fully handled.
            logger.warn(str(e))
        print(ret['g_report'])
        print(ret['g_f1'])
        print(ret['g_matrix'])
        logger.info("[sequence classification F1] : {}, {}".format(
            ret['g_f1'], total_examples))
        # write predicted glabels to file
        write_gprediction(args, gpreds, glabels)

    logger.info("[token classification F1] : {}, {}".format(
        ret['f1'], total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(
        total_examples, whole_time, avg_time))
    # NOTE(review): ZeroDivisionError when total_examples == 1 — verify.
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time, total_duration_time / (total_examples - 1)))
Exemple #18
0
                        num_workers=0,
                        drop_last=True)
writer = tensorboard.SummaryWriter(log_dir=f'test/{logs_idx}')
if len(saves) != 0:
    saves.sort(key=os.path.getmtime)
    checkpoint = torch.load(saves[-1], )
    text_embedding.load_state_dict(checkpoint['text_embedding'])
    text_embedding.eval()
    encoder.load_state_dict(checkpoint['encoder'])
    encoder.eval()
    decoder.load_state_dict(checkpoint['decoder'])
    decoder.eval()

for idx, batch in enumerate(
        tqdm(BackgroundGenerator(dataloader), total=len(dataloader))):
    text_data, text_pos, text_mask, mel_data, mel_pos, mel_mask, gate = to_device(
        batch, device)
    # audio_data = F.avg_pool1d(audio_data, kernel_size=2, padding=1)
    with torch.no_grad():
        text_emb = text_embedding(text_data)
        text_pos_emb = pos_embedding_(text_pos)
        enc_out, att_heads_enc = encoder(text_emb, text_mask, text_pos_emb)

        mel_pos = torch.arange(1, 512).view(1, 511).to(device)
        mel_pos_emb_ = pos_embedding(mel_pos)
        mel_mask_ = torch.triu(torch.ones(511, 511, dtype=torch.bool),
                               1).unsqueeze(0).to(device)
        # [B, T, C], [B, T, C], [B, T, 1], [B, T, T_text]
        mel = torch.zeros(1, 511, 80).to(device)
        for pos_idx in tqdm(range(511)):
            mel_pos_emb = mel_pos_emb_[:, :pos_idx + 1]
            mel_mask = mel_mask_[:, :pos_idx + 1, :pos_idx + 1]
Exemple #19
0
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    """Train a sequence tagger (optionally CRF-topped) for one epoch.

    Args:
        model: tagger; `model(x)` returns logits, or (logits, prediction)
            when `opt.use_crf`, and exposes `label_size` and (CRF) `crf`.
        config: dict with 'optimizer', 'scheduler', 'writer', 'scaler',
            'pad_label_id' and 'opt' (device, use_amp, use_crf, use_profiler,
            gradient_accumulation_steps, max_grad_norm,
            use_transformers_optimizer).
        train_loader: iterable of (x, y) training batches.
        val_loader: validation loader, forwarded to `evaluate()`.
        epoch_i: zero-based epoch index, used to derive the global step.

    Returns:
        (eval_loss, eval_f1) taken from `evaluate()`'s result dict.
    """
    opt = config['opt']

    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    pad_label_id = config['pad_label_id']

    criterion = nn.CrossEntropyLoss(ignore_index=pad_label_id).to(opt.device)
    n_batches = len(train_loader)
    prog = Progbar(target=n_batches)
    # train one epoch
    model.train()
    train_loss = 0.
    st_time = time.time()
    optimizer.zero_grad()
    for local_step, (x, y) in enumerate(train_loader):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        with autocast(enabled=opt.use_amp):
            # single forward pass (optionally profiled), shared by both
            # the CRF and plain cross-entropy paths below
            if opt.use_profiler:
                with profiler.profile(profile_memory=True, record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            if opt.use_crf:
                logits, prediction = output
                # mask zero entries of x[0] (presumably padding token ids)
                mask = torch.sign(torch.abs(x[0])).to(torch.uint8).to(opt.device)
                log_likelihood = model.crf(logits, y, mask=mask, reduction='mean')
                loss = -1 * log_likelihood
            else:
                logits = output
                # reshape for computing loss
                logits_view = logits.view(-1, model.label_size)
                y_view = y.view(-1)
                loss = criterion(logits_view, y_view)
            # scale down so the accumulated gradient matches the full-batch gradient
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            # unscale before clipping so max_grad_norm applies to true gradients
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), opt.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            if opt.use_transformers_optimizer: scheduler.step()
        # back-propagation - end
        train_loss += loss.item()
        if writer: writer.add_scalar('Loss/train', loss.item(), global_step)
        curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
        prog.update(local_step+1,
                    [('global step', global_step),
                     ('train curr loss', loss.item()),
                     ('lr', curr_lr)])
    train_loss = train_loss / n_batches

    # evaluate
    eval_ret = evaluate(model, config, val_loader)
    eval_loss = eval_ret['loss']
    eval_f1 = eval_ret['f1']
    elapsed_time = (time.time() - st_time) / 60
    curr_lr = scheduler.get_last_lr()[0] if scheduler else optimizer.param_groups[0]['lr']
    logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:10.6f}, valid loss {:10.6f}, valid f1 {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'.\
            format(epoch_i, local_step+1, len(train_loader), train_loss, eval_loss, eval_f1, curr_lr, elapsed_time)) 
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('F1/valid', eval_f1, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_f1
def train_eprop3(model, optimizer, loss_function, batch_x, batch_y,
                 truncation_delta):
    """One training pass of e-prop 3: truncated BPTT with synthetic gradients.

    Walks the sequence in windows of `truncation_delta` steps. For each pair
    of consecutive windows it backprops the task loss on both, then trains
    the synthetic-gradient output of the first window to match the true
    gradient of the second window's initial hidden state (MSE). The last
    window's synthetic gradient is regressed toward zero.

    Args:
        model: recurrent model; `model(x)` returns
            (prediction, initial_state, synthetic_gradient), and the model
            exposes `reset()` and `single_output`.
        optimizer: optimizer over `model`'s parameters.
        loss_function: task loss applied to (pred, gt).
        batch_x: 3-D input batch with time on axis 1 — assumed
            (batch, seq_len, features); TODO confirm.
        batch_y: targets, sliced along axis 1 in step with `batch_x`.
        truncation_delta: window length; the final zero-gradient loss only
            triggers when it divides seq_len evenly.

    Returns:
        float: task loss over the full sequence, computed with no_grad.
    """
    seq_len = batch_x.shape[1]
    loss_function_synth = nn.MSELoss()

    # reset model
    model.reset()

    # implementation of algo on page 33 of eprop paper
    for i, start in enumerate(
            range(truncation_delta, seq_len, truncation_delta)):
        # reset gradient
        optimizer.zero_grad()

        # select [t_{m-1}+1, ..., t_{m}]
        first_batch_x = batch_x[:, start - truncation_delta:start, :].clone()
        first_batch_y = batch_y[:, start - truncation_delta:start, :].clone()

        # simulate network over [t_{m-1}+1, ..., t_{m}] and backprop using the synthetic gradient
        prediction, _, first_synth_grad = model(first_batch_x)

        pred, gt = format_pred_and_gt(prediction, first_batch_y,
                                      model.single_output)
        loss = loss_function(pred, gt)
        loss.backward()

        # select [t_{m}+1, ..., t_{m+1}] (the next truncated time interval)
        second_batch_x = batch_x[:, start:start + truncation_delta, :].clone()
        second_batch_y = batch_y[:, start:start + truncation_delta, :].clone()

        # simulate and backprop using second synthetic gradient
        prediction, second_initial_state, second_synth_grad = model(
            second_batch_x)

        # retain grad of the initial hidden state of the second interval ...
        second_initial_state.retain_grad()

        pred, gt = format_pred_and_gt(prediction, second_batch_y,
                                      model.single_output)
        loss = loss_function(pred, gt)
        loss.backward()

        # ... and store it ...
        real_grad_x = second_initial_state.grad.detach()

        # ... to optimize the synth grad network using MSE
        loss = loss_function_synth(first_synth_grad, real_grad_x)
        loss.backward()

        real_grad_x_shape = real_grad_x.shape

        # train the final synthetic gradient to be close to 0
        if start + truncation_delta == seq_len:
            # NOTE(review): to_device is called with a single argument here,
            # unlike the two-argument usages elsewhere — presumably a variant
            # defaulting to a global device; confirm.
            zeros = to_device(
                torch.zeros(real_grad_x_shape, requires_grad=False))
            loss = loss_function_synth(second_synth_grad, zeros)
            loss.backward()

        # single optimizer step applies all gradients accumulated above
        optimizer.step()

    # report the task loss over the whole sequence (no parameter update)
    with torch.no_grad():
        prediction, _, _ = model(batch_x)
        pred, gt = format_pred_and_gt(prediction, batch_y, model.single_output)
        loss = loss_function(pred, gt)

    return loss.item()
Exemple #21
0
def train_epoch(model, config, train_loader, val_loader, epoch_i):
    """Train `model` for one epoch (apex-amp, distributed-aware variant).

    Args:
        model: model under training; invoked as `model(x)`.
        config: dict providing 'optimizer', 'scheduler', 'writer' and 'opt'
            (options: device, criterion, use_amp, local_rank,
            gradient_accumulation_steps, max_grad_norm,
            use_transformers_optimizer).
        train_loader: iterable of (x, y) training batches.
        val_loader: validation loader, forwarded to `evaluate()`.
        epoch_i: zero-based epoch index, used to derive the global step.

    Returns:
        (eval_loss, eval_acc) as returned by `evaluate()` on `val_loader`.
    """
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    opt = config['opt']

    local_rank = opt.local_rank
    use_amp = opt.use_amp
    # 'sum' reduction makes the MSE loss scale with batch size.
    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)

    # train one epoch
    model.train()
    optimizer.zero_grad()
    total_loss = 0.
    total_examples = 0
    st_time = time.time()
    for local_step, (x, y) in tqdm(enumerate(train_loader),
                                   total=len(train_loader)):
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        output = model(x)
        loss = criterion(output, y)
        # back-propagation - begin
        # scale down so the accumulated gradient matches the full-batch gradient
        if opt.gradient_accumulation_steps > 1:
            loss = loss / opt.gradient_accumulation_steps
        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                try:
                    scaled_loss.backward()
                except Exception as e:
                    # NOTE(review): swallowing backward errors keeps training
                    # alive but silently skips the update — consider re-raising.
                    print(e)
        else:
            loss.backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.max_grad_norm)
            optimizer.step()
            if opt.use_transformers_optimizer: scheduler.step()
            optimizer.zero_grad()
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        # example-weighted running loss
        total_loss += (loss.item() * cur_examples)
        # only rank 0 writes tensorboard scalars in distributed runs
        if local_rank == 0 and writer:
            writer.add_scalar('Loss/train', loss.item(), global_step)
    cur_loss = total_loss / total_examples

    # evaluate
    eval_loss, eval_acc = evaluate(model, config, val_loader)
    elapsed_time = (time.time() - st_time) / 60
    # get_lr() is deprecated on newer torch schedulers (use get_last_lr());
    # kept as-is for compatibility with the scheduler used here.
    curr_lr = scheduler.get_lr(
    )[0] if scheduler else optimizer.param_groups[0]['lr']
    if local_rank == 0:
        logger.info('{:3d} epoch | {:5d}/{:5d} | train loss : {:6.3f}, valid loss {:6.3f}, valid acc {:.4f}| lr :{:7.6f} | {:5.2f} min elapsed'.\
                format(epoch_i, local_step+1, len(train_loader), cur_loss, eval_loss, eval_acc, curr_lr, elapsed_time))
        if writer:
            writer.add_scalar('Loss/valid', eval_loss, global_step)
            writer.add_scalar('Acc/valid', eval_acc, global_step)
            writer.add_scalar('LearningRate/train', curr_lr, global_step)
    return eval_loss, eval_acc
def distill(
        teacher_config,
        teacher_model,
        student_config,
        student_model,
        train_loader,
        eval_loader,
        best_eval_metric=None,
        mpl_loader=None):
    """Distill `teacher_model` into `student_model`, optionally with meta
    pseudo labels (MPL) feedback updating the teacher.

    Per training step: the teacher runs in eval mode and the student is
    trained on a knowledge-distillation loss — MSE over logits, plus optional
    hidden-state and attention losses weighted by `args.state_loss_ratio` /
    `args.att_loss_ratio`. When `mpl_loader` is given and warmup has passed,
    the teacher is additionally updated from the cross-entropy of both models
    on a labeled sample. The student is periodically evaluated and the best
    checkpoint (by accuracy) is saved.

    Returns:
        (global_step, average training loss, best_eval_metric).
    """

    args = teacher_config['opt']

    teacher_layer_num = teacher_model.bert_model.config.num_hidden_layers
    student_layer_num = student_model.bert_model.config.num_hidden_layers

    # create teacher optimizer with larger L2 norm
    teacher_optimizer, _, _, _ = prepare_osws(teacher_config, teacher_model, train_loader, lr=args.mpl_learning_rate, weight_decay=args.mpl_weight_decay)

    # create student optimizer, scheduler, summary writer
    student_optimizer, student_scheduler, writer, _ = prepare_osws(student_config, student_model, train_loader, lr=args.lr, weight_decay=args.weight_decay)

    # prepare loss functions
    def soft_cross_entropy(predicts, targets):
        # KL-style soft-target cross entropy (currently unused; see kd_loss below)
        likelihood = F.log_softmax(predicts, dim=-1)
        targets_prob = F.softmax(targets, dim=-1)
        return (- targets_prob * likelihood).sum(dim=-1).mean()

    loss_mse_sum = MSELoss(reduction='sum').to(args.device)
    loss_mse = MSELoss().to(args.device)
    loss_cs = CosineSimilarity(dim=2).to(args.device)
    loss_cs_att = CosineSimilarity(dim=3).to(args.device)

    logger.info("***** Running distillation training *****")
    logger.info("  Num Batchs = %d", len(train_loader))
    logger.info("  Num Epochs = %d", args.epoch)
    logger.info("  batch size = %d", args.batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_att_loss = 0.
    tr_rep_loss = 0.
    tr_cls_loss = 0.
    teacher_model.zero_grad()
    student_model.zero_grad()
    epoch_iterator = range(epochs_trained, int(args.epoch))

    # for reproductibility
    set_seed(args)

    for epoch_n in epoch_iterator:
        tr_att_loss = 0.
        tr_rep_loss = 0.
        tr_cls_loss = 0.
        train_iterator = tqdm(train_loader, desc=f"Epoch {epoch_n}")
        for step, (x, y) in enumerate(train_iterator):
            x = to_device(x, args.device)
            y = to_device(y, args.device)

            # -------------------------------------------------------------------------------------------------------
            # teacher -> student, teaching with teacher_model.eval(), student_model.train()
            # -------------------------------------------------------------------------------------------------------
            att_loss = 0.
            rep_loss = 0.
            cls_loss = 0.

            # teacher model output
            teacher_model.eval()
            with torch.no_grad():
                output_teacher, teacher_bert_outputs = teacher_model(x, return_bert_outputs=True)

            # student model output
            student_model.train()
            output_student, student_bert_outputs = student_model(x, return_bert_outputs=True)

            # Knowledge Distillation loss
            # 1) logits distillation
            '''
            kd_loss = soft_cross_entropy(output_student, output_teacher)
            '''
            kd_loss = loss_mse_sum(output_student, output_teacher)

            loss = kd_loss
            tr_cls_loss += loss.item()

            # 2) embedding and last hidden state distillation
            if args.state_loss_ratio > 0.0:
                teacher_reps = teacher_bert_outputs.hidden_states
                student_reps = student_bert_outputs.hidden_states

                # align only the embedding layer (index 0) and the last layer
                new_teacher_reps = [teacher_reps[0], teacher_reps[teacher_layer_num]]
                new_student_reps = [student_reps[0], student_reps[student_layer_num]]
                for student_rep, teacher_rep in zip(new_student_reps, new_teacher_reps):
                    # cosine similarity loss
                    if args.state_distill_cs:
                        tmp_loss = 1.0 - loss_cs(student_rep, teacher_rep).mean()
                    # MSE loss
                    else:
                        tmp_loss = loss_mse(student_rep, teacher_rep)
                    rep_loss += tmp_loss
                loss += args.state_loss_ratio * rep_loss
                tr_rep_loss += rep_loss.item()

            # 3) Attentions distillation
            if args.att_loss_ratio > 0.0:
                teacher_atts = teacher_bert_outputs.attentions
                student_atts = student_bert_outputs.attentions

                assert teacher_layer_num == len(teacher_atts)
                assert student_layer_num == len(student_atts)
                assert teacher_layer_num % student_layer_num == 0
                # map every student layer to the last teacher layer of its block
                layers_per_block = int(teacher_layer_num / student_layer_num)
                new_teacher_atts = [teacher_atts[i * layers_per_block + layers_per_block - 1]
                                    for i in range(student_layer_num)]

                for student_att, teacher_att in zip(student_atts, new_teacher_atts):
                    # zero out large-negative masked positions before comparing
                    student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att).to(args.device),
                                              student_att)
                    teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att).to(args.device),
                                              teacher_att)
                    tmp_loss = 1.0 - loss_cs_att(student_att, teacher_att).mean()
                    att_loss += tmp_loss

                loss += args.att_loss_ratio * att_loss
                tr_att_loss += att_loss.item()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # back propagate through student model
            loss.backward()
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), args.max_grad_norm)
                student_optimizer.step()  # update student model
                student_scheduler.step()  # Update learning rate schedule
                student_model.zero_grad()
                global_step += 1
            # -------------------------------------------------------------------------------------------------------

            # -------------------------------------------------------------------------------------------------------
            # student -> teacher, performance feedback/update with student_model.eval(), teacher_model.train()
            # -------------------------------------------------------------------------------------------------------
            mpl_loss = 0.0
            if mpl_loader and global_step > args.mpl_warmup_steps: 
                loss_cross_entropy = torch.nn.CrossEntropyLoss().to(args.device)
                # NOTE(review): a fresh iterator per step always draws the first
                # batch unless the loader shuffles — confirm mpl_loader shuffles.
                mpl_iterator = iter(mpl_loader)
                try:
                    (x, y) = next(mpl_iterator) # draw random sample
                except StopIteration as e:
                    mpl_iterator = iter(mpl_loader)
                    (x, y) = next(mpl_iterator) # draw random sample
                x = to_device(x, args.device)
                y = to_device(y, args.device)

                # teacher model output
                teacher_model.train()
                output_teacher, teacher_bert_outputs = teacher_model(x, return_bert_outputs=True)

                # student model output
                student_model.eval() # updated student model
                output_student, student_bert_outputs = student_model(x, return_bert_outputs=True)

                # the loss is the performance of the student on the labeled data.
                # additionaly, we add the loss of the teacher on the labeled data for avoiding overfitting.
                mpl_loss = loss_cross_entropy(output_student, y) / 2 + loss_cross_entropy(output_teacher, y) / 2
                if args.gradient_accumulation_steps > 1:
                    mpl_loss = mpl_loss / args.gradient_accumulation_steps

                # back propagate through teacher model
                mpl_loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(teacher_model.parameters(), args.max_grad_norm)
                    teacher_optimizer.step() # update teacher model 
                    teacher_model.zero_grad()
                    student_model.zero_grad() # clear gradient info which was generated during forward computation.
            # -------------------------------------------------------------------------------------------------------

            train_iterator.set_description(f"Epoch {epoch_n} loss: {loss:.3f}, mpl loss: {mpl_loss:.3f}")
            if writer:
                writer.add_scalar('loss', loss, global_step)
                writer.add_scalar('mpl_loss', mpl_loss, global_step)

            # -------------------------------------------------------------------------------------------------------
            # evaluate student, save model
            # -------------------------------------------------------------------------------------------------------
            flag_eval = False
            logs = {}
            if args.logging_steps > 0 and global_step % args.logging_steps == 0: flag_eval = True
            if flag_eval:
                if args.log_evaluate_during_training:
                    eval_loss, eval_acc = evaluate(student_model, student_config, eval_loader)
                    logs['eval_loss'] = eval_loss
                    logs['eval_acc'] = eval_acc
                    if writer:
                        writer.add_scalar('eval_loss', eval_loss, global_step)
                        writer.add_scalar('eval_acc', eval_acc, global_step)

                cls_loss = tr_cls_loss / (step + 1)
                att_loss = tr_att_loss / (step + 1)
                rep_loss = tr_rep_loss / (step + 1)

                loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                learning_rate_scalar = student_scheduler.get_last_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["avg_loss_since_last_log"] = loss_scalar
                logs['cls_loss'] = cls_loss
                logs['att_loss'] = att_loss
                logs['rep_loss'] = rep_loss
                logging_loss = tr_loss
                # NOTE(review): uses the root `logging` module here but `logger`
                # elsewhere — presumably unintentional; confirm.
                logging.info(json.dumps({**logs, **{"step": global_step}}))
                if writer:
                    writer.add_scalar('learning_rate', learning_rate_scalar, global_step)
                    writer.add_scalar('avg_loss_since_last_log', loss_scalar, global_step)
                    writer.add_scalar('cls_loss', cls_loss, global_step)
                    writer.add_scalar('att_loss', att_loss, global_step)
                    writer.add_scalar('rep_loss', rep_loss, global_step)

            flag_eval = False
            if step == 0 and epoch_n != 0: flag_eval = True # every epoch
            if args.eval_and_save_steps > 0 and global_step % args.eval_and_save_steps == 0: flag_eval = True
            if flag_eval:
                eval_loss, eval_acc = evaluate(student_model, student_config, eval_loader)
                logs['eval_loss'] = eval_loss
                logs['eval_acc'] = eval_acc
                logger.info(json.dumps({**logs, **{"step": global_step}}))
                if writer:
                    writer.add_scalar('eval_loss', eval_loss, global_step)
                    writer.add_scalar('eval_acc', eval_acc, global_step)
                # measured by accuracy
                curr_eval_metric = eval_acc
                if best_eval_metric is None or curr_eval_metric > best_eval_metric:
                    # save model to '--save_path', '--bert_output_dir'
                    save_model(student_config, student_model, save_path=args.save_path)
                    student_model.bert_tokenizer.save_pretrained(args.bert_output_dir)
                    student_model.bert_model.save_pretrained(args.bert_output_dir)
                    best_eval_metric = curr_eval_metric
                    logger.info("[Best student model saved] : {:10.6f}, {}, {}".format(best_eval_metric, args.bert_output_dir, args.save_path))
            # -------------------------------------------------------------------------------------------------------

    # NOTE(review): ZeroDivisionError if no optimizer step ever ran
    # (global_step == 0) — confirm loaders are non-trivially sized.
    return global_step, tr_loss / global_step, best_eval_metric
Exemple #23
0
def evaluate(opt):
    """Evaluate a sentence-classification model on the test dataset.

    Depending on ``opt`` flags this function may instead:
      * export the model to ONNX (optionally quantized) and return early,
      * export the model to TVM format and return early, or
      * run inference through onnxruntime or a dynamically quantized
        pytorch model.

    Prints a classification report and confusion matrix, writes the raw
    predictions to file, and logs accuracy / latency statistics.
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx and return early
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        logger.info("[ONNX model saved] :{}".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(
                opt.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # convert to tvm format and return early
    if opt.convert_tvm:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_tvm(config, model, x)
        logger.info("[TVM model saved] : {}".format(opt.tvm_path))
        return

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    correct = 0
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    # roberta/distilbert/bart exported models take 2 inputs;
                    # other transformer models take 3.
                    if config['emb_class'] in [
                            'roberta', 'distilbert', 'bart'
                    ]:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)

            # accumulate logits / gold labels for the final report
            if preds is None:
                preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            predicted = logits.argmax(1)
            correct += (predicted == y).sum().item()
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
    # generate report
    labels = config['labels']
    label_names = [v for k, v in sorted(labels.items(), key=lambda x: x[0])]
    preds_ids = np.argmax(preds, axis=1)
    try:
        print(
            classification_report(ys,
                                  preds_ids,
                                  target_names=label_names,
                                  digits=4))
        print(labels)
        print(confusion_matrix(ys, preds_ids))
    except Exception as e:
        # logger.warn() is a deprecated alias of warning().
        logger.warning(str(e))

    acc = correct / total_examples
    whole_time = float((time.time() - whole_st_time) * 1000)
    # Exclude the warm-up (first) batch from the per-example average and
    # guard against ZeroDivisionError when only one batch was evaluated.
    examples_after_first = total_examples - first_examples
    avg_time = ((whole_time - first_time) / examples_after_first
                if examples_after_first > 0 else 0.0)
    # write predictions to file
    write_prediction(opt, preds, labels)
    logger.info("[Accuracy] : {:.4f}, {:5d}/{:5d}".format(
        acc, correct, total_examples))
    logger.info("[Elapsed Time] : {}ms, {}ms on average".format(
        whole_time, avg_time))
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time,
            total_duration_time / (total_examples - 1)
            if total_examples > 1 else 0.0))
Exemple #24
0
def main():
    """Train COVIDNext50 on the COVIDx dataset.

    Builds the train/val data loaders, optionally restores a checkpoint
    from ``config.weights``, then runs the training loop with periodic
    metric logging and validation-driven LR scheduling / checkpointing.

    Raises:
        ValueError: if ``config.gpu`` is set but CUDA is unavailable.
    """
    if config.gpu and not torch.cuda.is_available():
        raise ValueError("GPU not supported or enabled on this system.")
    use_gpu = config.gpu

    log.info("Loading train dataset")
    train_dataset = COVIDxFolder(
        config.train_imgs, config.train_labels,
        transforms.train_transforms(config.width, config.height))
    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              shuffle=True,
                              drop_last=True,
                              num_workers=config.n_threads,
                              pin_memory=use_gpu)
    log.info("Number of training examples {}".format(len(train_dataset)))

    log.info("Loading val dataset")
    val_dataset = COVIDxFolder(
        config.val_imgs, config.val_labels,
        transforms.val_transforms(config.width, config.height))
    val_loader = DataLoader(val_dataset,
                            batch_size=config.batch_size,
                            shuffle=False,
                            num_workers=config.n_threads,
                            pin_memory=use_gpu)
    log.info("Number of validation examples {}".format(len(val_dataset)))

    if config.weights:
        # BUG FIX: the checkpoint load was commented out (state was
        # hard-coded to None) while the log still claimed weights were
        # loaded. Actually load it; map to CPU first so it also works
        # before the model is moved to GPU.
        state = torch.load(config.weights, map_location="cpu")
        log.info("Loaded model weights from: {}".format(config.weights))
    else:
        state = None

    state_dict = state["state_dict"] if state else None
    model = architecture.COVIDNext50(n_classes=config.n_classes)
    if state_dict:
        model = util.load_model_weights(model=model, state_dict=state_dict)

    if use_gpu:
        model.cuda()
        model = torch.nn.DataParallel(model)
    optim_layers = filter(lambda p: p.requires_grad, model.parameters())

    # optimizer and lr scheduler
    optimizer = Adam(optim_layers,
                     lr=config.lr,
                     weight_decay=config.weight_decay)
    scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                  factor=config.lr_reduce_factor,
                                  patience=config.lr_reduce_patience,
                                  mode='max',
                                  min_lr=1e-7)

    # Load the last global_step from the checkpoint if existing
    global_step = 0 if state is None else state['global_step'] + 1

    class_weights = util.to_device(torch.FloatTensor(config.loss_weights),
                                   gpu=use_gpu)
    # BUG FIX: class_weights was computed but never used, silently
    # ignoring config.loss_weights.
    loss_fn = CrossEntropyLoss(weight=class_weights)

    # Reset the best metric score
    best_score = -1
    for epoch in range(config.epochs):
        log.info("Started epoch {}/{}".format(epoch + 1, config.epochs))
        for data in train_loader:
            imgs, labels = data
            imgs = util.to_device(imgs, gpu=use_gpu)
            labels = util.to_device(labels, gpu=use_gpu)

            logits = model(imgs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_step % config.log_steps == 0 and global_step > 0:
                # NOTE(review): model.module assumes DataParallel wrapping,
                # i.e. use_gpu is True — confirm CPU runs are not expected here.
                probs = model.module.probability(logits)
                preds = torch.argmax(probs, dim=1).detach().cpu().numpy()
                labels = labels.cpu().detach().numpy()
                acc, f1, _, _ = util.clf_metrics(preds, labels)
                lr = util.get_learning_rate(optimizer)

                log.info("Step {} | TRAINING batch: Loss {:.4f} | F1 {:.4f} | "
                         "Accuracy {:.4f} | LR {:.2e}".format(
                             global_step, loss.item(), f1, acc, lr))

            if global_step % config.eval_steps == 0 and global_step > 0:
                best_score = validate(val_loader,
                                      model,
                                      best_score=best_score,
                                      global_step=global_step,
                                      cfg=config)
                scheduler.step(best_score)
            global_step += 1
Exemple #25
0
def evaluate(opt):
    """Evaluate a sequence-labeling model and report seqeval metrics.

    Depending on ``opt`` flags this function may instead export the model
    to ONNX (optionally quantized) and return early, or run inference
    through onnxruntime / a dynamically quantized pytorch model.

    Converts per-token predictions back to label strings (skipping pad
    positions), prints a seqeval classification report, writes predicted
    labels to file, and logs F1 plus latency statistics.
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt
    logger.info("%s", config)

    # set path
    set_path(config)

    # prepare test dataset
    test_loader = prepare_datasets(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(config)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # convert to onnx format and return early
    if opt.convert_onnx:
        (x, y) = next(iter(test_loader))
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        convert_onnx(config, model, x)
        check_onnx(config)
        # fixed log message: it was missing its closing bracket
        logger.info("[ONNX model saved] : {}".format(opt.onnx_path))
        # quantize onnx
        if opt.quantize_onnx:
            quantize_onnx(opt.onnx_path, opt.quantized_onnx_path)
            logger.info("[Quantized ONNX model saved] : {}".format(
                opt.quantized_onnx_path))
        return

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # evaluation
    preds = None
    ys = None
    n_batches = len(test_loader)
    total_examples = 0
    whole_st_time = time.time()
    first_time = time.time()
    first_examples = 0
    total_duration_time = 0.0
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader, total=n_batches)):
            start_time = time.time()
            x = to_device(x, opt.device)
            y = to_device(y, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {
                        ort_session.get_inputs()[0].name: x[0],
                        ort_session.get_inputs()[1].name: x[1]
                    }
                    if opt.use_char_cnn:
                        ort_inputs[ort_session.get_inputs()[2].name] = x[2]
                if config['emb_class'] in [
                        'bert', 'distilbert', 'albert', 'roberta', 'bart',
                        'electra'
                ]:
                    # distilbert/bart exported models take 2 inputs, others 3
                    if config['emb_class'] in ['distilbert', 'bart']:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                    if opt.bert_use_pos:
                        ort_inputs[ort_session.get_inputs()[3].name] = x[3]
                if opt.use_crf:
                    logits, prediction = ort_session.run(None, ort_inputs)
                    prediction = to_device(torch.tensor(prediction),
                                           opt.device)
                    logits = to_device(torch.tensor(logits), opt.device)
                else:
                    logits = ort_session.run(None, ort_inputs)[0]
                    logits = to_device(torch.tensor(logits), opt.device)
            else:
                if opt.use_crf: logits, prediction = model(x)
                else: logits = model(x)

            # accumulate CRF decodes (or raw logits) and gold labels
            if preds is None:
                if opt.use_crf: preds = to_numpy(prediction)
                else: preds = to_numpy(logits)
                ys = to_numpy(y)
            else:
                if opt.use_crf:
                    preds = np.append(preds, to_numpy(prediction), axis=0)
                else:
                    preds = np.append(preds, to_numpy(logits), axis=0)
                ys = np.append(ys, to_numpy(y), axis=0)
            cur_examples = y.size(0)
            total_examples += cur_examples
            if i == 0:  # first one may take longer time, so ignore in computing duration.
                first_time = float((time.time() - first_time) * 1000)
                first_examples = cur_examples
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Evaluation] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
    whole_time = float((time.time() - whole_st_time) * 1000)
    # Exclude the warm-up (first) batch from the per-example average and
    # guard against ZeroDivisionError when only one batch was evaluated.
    examples_after_first = total_examples - first_examples
    avg_time = ((whole_time - first_time) / examples_after_first
                if examples_after_first > 0 else 0.0)
    # without a CRF, argmax over the label dimension yields label ids
    if not opt.use_crf: preds = np.argmax(preds, axis=2)
    # compute measure using seqeval
    labels = model.labels
    ys_lbs = [[] for _ in range(ys.shape[0])]
    preds_lbs = [[] for _ in range(ys.shape[0])]
    pad_label_id = config['pad_label_id']
    for i in range(ys.shape[0]):  # foreach sentence
        for j in range(ys.shape[1]):  # foreach token
            if ys[i][j] != pad_label_id:
                ys_lbs[i].append(labels[ys[i][j]])
                preds_lbs[i].append(labels[preds[i][j]])
    ret = {
        "precision": precision_score(ys_lbs, preds_lbs),
        "recall": recall_score(ys_lbs, preds_lbs),
        "f1": f1_score(ys_lbs, preds_lbs),
        "report": classification_report(ys_lbs, preds_lbs, digits=4),
    }
    print(ret['report'])
    f1 = ret['f1']
    # write predicted labels to file
    default_label = config['default_label']
    write_prediction(opt, ys, preds, labels, pad_label_id, default_label)

    logger.info("[F1] : {}, {}".format(f1, total_examples))
    logger.info("[Elapsed Time] : {} examples, {}ms, {}ms on average".format(
        total_examples, whole_time, avg_time))
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time,
            total_duration_time / (total_examples - 1)
            if total_examples > 1 else 0.0))
Exemple #26
0
    def forward(ctx,
                ev_w_ih_x,
                ev_w_hh_x,
                ev_b_x,
                forgetgate_x,
                input_data,
                hx,
                cx,
                weight_ih,
                weight_hh,
                bias_ih=None,
                bias_hh=None):
        """LSTM cell forward pass that also updates eligibility vectors/traces.

        Computes the standard LSTM cell output (hy, cy) from the incoming
        hidden/cell states (hx, cx), then updates eligibility vectors
        (ev_*) and builds eligibility traces (et_*) for the input-to-hidden
        weights, hidden-to-hidden weights and biases. The traces are saved
        via ``ctx.save_for_backward`` for use in the backward pass.

        Parameters:
            ev_w_ih_x, ev_w_hh_x, ev_b_x: previous-step eligibility
                vectors for weight_ih / weight_hh / bias.
            forgetgate_x: previous-step forget-gate activations
                (repeated x3 below to cover ingate/forgetgate/cellgate rows).
            input_data: cell input, shape (batch, input_size).
            hx, cx: previous hidden and cell state.
            weight_ih, weight_hh, bias_ih, bias_hh: LSTM cell parameters;
                gates are stacked 4*hidden along dim 0 in the order
                ingate, forgetgate, cellgate, outgate (see chunk(4, 1)).

        Returns:
            (hy, cy, ev_w_ih_y, ev_w_hh_y, ev_b_y, forgetgate_y)

        NOTE(review): bias_ih/bias_hh default to None but are added
        unconditionally below — confirm callers always pass biases.
        """
        # calculate gates ...
        gates = (torch.mm(input_data, weight_ih.t()) + bias_ih +
                 torch.mm(hx, weight_hh.t()) + bias_hh)
        ingate, forgetgate_y, cellgate, outgate = gates.chunk(4, 1)

        # ... and gate activations
        ingate = torch.sigmoid(ingate)
        forgetgate_y = torch.sigmoid(forgetgate_y)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        # standard LSTM state update
        cy = (forgetgate_y * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        # TODO: calculate new eligibility vector and trace
        # There exist distinct eligibility traces and vectors for the followiug parts of the LSTM cell:
        # - input to hidden connections, hidden to hidden connections, bias
        # all for each:
        # - ... inputgate, forgetgate and cellgate
        # => overall 3 * 3 = 9 eligibility traces
        hidden_size = hy.size(1)
        batch_size = input_data.size(0)
        input_size = input_data.size(1)

        # add a trailing (or middle) broadcast dimension so per-gate terms
        # can be combined with per-input / per-hidden terms via outer products
        ingate = ingate.unsqueeze(2)
        forgetgate_y = forgetgate_y.unsqueeze(2)
        cellgate = cellgate.unsqueeze(2)
        outgate = outgate.unsqueeze(2)
        input_data = input_data.unsqueeze(1)
        hx = hx.unsqueeze(2)
        cx = cx.unsqueeze(2)

        # one forget-gate factor per gate block (ingate/forgetgate/cellgate)
        forgetgate_x = forgetgate_x.repeat(1, 3, 1)
        ones = to_device(torch.ones(ingate.size()))

        # the new eligibility vectors ...
        # (decay previous vectors by the previous forget gate)
        ev_w_ih_y = ev_w_ih_x * forgetgate_x
        ev_w_hh_y = ev_w_hh_x * forgetgate_x
        ev_b_y = ev_b_x * forgetgate_x

        # ingate: rows [0, hidden_size) of the stacked gate dimension
        base = ingate * (ones - ingate) * cellgate
        ev_w_hh_y[:, :hidden_size, :] += base * hx
        ev_w_ih_y[:, :hidden_size, :] += base * input_data
        ev_b_y[:, :hidden_size, :] += base

        # forgetgate: rows [hidden_size, 2*hidden_size)
        #base = forgetgate_y * (ones - forgetgate_y) * cellgate
        base = forgetgate_y * (ones - forgetgate_y) * cx
        ev_w_hh_y[:, hidden_size:(2 * hidden_size), :] += base * hx
        ev_w_ih_y[:, hidden_size:(2 * hidden_size), :] += base * input_data
        ev_b_y[:, hidden_size:(2 * hidden_size), :] += base

        # cellgate: rows [2*hidden_size, 3*hidden_size)
        base = ingate * (ones - cellgate**2)
        ev_w_hh_y[:, (2 * hidden_size):(3 * hidden_size), :] += base * hx
        ev_w_ih_y[:,
                  (2 * hidden_size):(3 * hidden_size), :] += base * input_data
        ev_b_y[:, (2 * hidden_size):(3 * hidden_size), :] += base

        # ... and eligibility traces (full 4*hidden rows; outgate rows filled below)
        et_w_ih_y = to_device(
            torch.zeros(batch_size,
                        4 * hidden_size,
                        input_size,
                        requires_grad=False))
        et_w_hh_y = to_device(
            torch.zeros(batch_size,
                        4 * hidden_size,
                        hidden_size,
                        requires_grad=False))
        et_b_y = to_device(
            torch.zeros(batch_size, 4 * hidden_size, 1, requires_grad=False))

        # calculate eligibility traces by multiplying the eligibility vectors with the outgate
        tmp_outgate = outgate.repeat(1, 3, 1)
        et_w_ih_y[:, :3 * hidden_size, :] = ev_w_ih_y * tmp_outgate
        et_w_hh_y[:, :3 * hidden_size, :] = ev_w_hh_y * tmp_outgate
        et_b_y[:, :3 * hidden_size, :] = ev_b_y * tmp_outgate

        # The gradient of the output gate is only dependent on the observable state
        # => just use normal gradient calculation of dE/dh * dh/dweight
        # => calculate second part of that equation now for input to hidden, hidden to hidden
        #    and bias connections and multiply in the backward pass
        base = outgate * (ones - outgate) * cy.unsqueeze(2)
        et_w_hh_y[:, (3 * hidden_size):(4 * hidden_size)] = base * hx
        et_w_ih_y[:, (3 * hidden_size):(4 * hidden_size)] = base * input_data
        et_b_y[:, (3 * hidden_size):(4 * hidden_size)] = base

        # stash everything the backward pass needs
        ctx.save_for_backward(et_w_ih_y, et_w_hh_y, et_b_y, weight_hh,
                              cx.clone().squeeze(),
                              cy.clone().squeeze(), outgate.squeeze(),
                              ingate.squeeze(), cellgate.squeeze(),
                              forgetgate_y.squeeze())

        return hy, cy, ev_w_ih_y, ev_w_hh_y, ev_b_y, forgetgate_y
Exemple #27
0
def inference(opt):
    """Run line-by-line classification inference over ``opt.test_path``.

    Each input line is expected to be tab-separated ``text<TAB>label``;
    the text is encoded, classified (by pytorch, onnxruntime, or a
    dynamically quantized model), and ``text<TAB>gold<TAB>predicted`` is
    written to ``opt.test_path + '.inference'``. Logs per-line and
    aggregate latency statistics.
    """
    # set config
    config = load_config(opt)
    if opt.num_threads > 0: torch.set_num_threads(opt.num_threads)
    config['opt'] = opt

    # set path: opt.embedding_path, opt.vocab_path, opt.label_path
    set_path(config)

    # load pytorch model checkpoint
    checkpoint = load_checkpoint(opt.model_path, device=opt.device)

    # prepare model and load parameters
    model = load_model(config, checkpoint)
    model.eval()

    # load onnx model for using onnxruntime
    if opt.enable_ort:
        import onnxruntime as ort
        sess_options = ort.SessionOptions()
        sess_options.inter_op_num_threads = opt.num_threads
        sess_options.intra_op_num_threads = opt.num_threads
        ort_session = ort.InferenceSession(opt.onnx_path,
                                           sess_options=sess_options)

    # enable to use dynamic quantized model (pytorch>=1.3.0)
    if opt.enable_dqm and opt.device == 'cpu':
        model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear},
                                                    dtype=torch.qint8)
        print(model)

    # prepare tokenizer
    tokenizer = prepare_tokenizer(config, model)

    # prepare labels
    labels = config['labels']

    # inference; context managers guarantee both files are closed even if
    # a line fails to parse or the model raises (the output file was
    # previously leaked on error).
    total_examples = 0
    total_duration_time = 0.0
    with open(opt.test_path + '.inference', 'w', encoding='utf-8') as f_out, \
         torch.no_grad(), \
         open(opt.test_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            start_time = time.time()
            sent, label = line.strip().split('\t')
            x_raw = sent.split()
            y_raw = label
            text = ' '.join(x_raw)
            x = encode_text(config, tokenizer, text)
            x = to_device(x, opt.device)

            if opt.enable_ort:
                x = to_numpy(x)
                if config['emb_class'] == 'glove':
                    ort_inputs = {ort_session.get_inputs()[0].name: x}
                else:
                    # roberta/distilbert/bart exported models take 2 inputs;
                    # other transformer models take 3.
                    if config['emb_class'] in [
                            'roberta', 'distilbert', 'bart'
                    ]:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1]
                        }
                    else:
                        ort_inputs = {
                            ort_session.get_inputs()[0].name: x[0],
                            ort_session.get_inputs()[1].name: x[1],
                            ort_session.get_inputs()[2].name: x[2]
                        }
                logits = ort_session.run(None, ort_inputs)[0]
                logits = to_device(torch.tensor(logits), opt.device)
            else:
                logits = model(x)

            predicted = logits.argmax(1)
            predicted = to_numpy(predicted)[0]
            predicted_raw = labels[predicted]
            f_out.write(text + '\t' + y_raw + '\t' + predicted_raw + '\n')
            total_examples += 1
            if opt.num_examples != 0 and total_examples >= opt.num_examples:
                logger.info("[Stop Inference] : up to the {} examples".format(
                    total_examples))
                break
            duration_time = float((time.time() - start_time) * 1000)
            if i != 0: total_duration_time += duration_time
            logger.info("[Elapsed Time] : {}ms".format(duration_time))
    # guard against ZeroDivisionError when at most one line was processed
    logger.info(
        "[Elapsed Time(total_duration_time, average)] : {}ms, {}ms".format(
            total_duration_time,
            total_duration_time / (total_examples - 1)
            if total_examples > 1 else 0.0))
Exemple #28
0
def train_epoch(model, config, train_loader, valid_loader, epoch_i,
                best_eval_measure):
    """Train ``model`` for one epoch, periodically evaluating and saving.

    Uses the optimizer / scheduler / tensorboard writer / AMP GradScaler
    stored in ``config``; supports gradient accumulation
    (``opt.gradient_accumulation_steps``), mixed precision via autocast
    (``opt.use_amp``), and optional profiling. Evaluates on
    ``valid_loader`` every ``opt.eval_and_save_steps`` and once at the
    end of the epoch; saves the best model according to ``opt.measure``
    ('loss' minimized, otherwise accuracy maximized), unless running a
    hyper-parameter search (optuna/nni).

    Returns:
        (local_best_eval_loss, local_best_eval_acc, best_eval_measure)
    """
    optimizer = config['optimizer']
    scheduler = config['scheduler']
    writer = config['writer']
    scaler = config['scaler']
    opt = config['opt']

    # criterion selection; all options except the default use 'sum' reduction
    if opt.criterion == 'MSELoss':
        criterion = torch.nn.MSELoss(reduction='sum').to(opt.device)
    elif opt.criterion == 'KLDivLoss':
        criterion = torch.nn.KLDivLoss(reduction='sum').to(opt.device)
    elif opt.criterion == 'LabelSmoothingCrossEntropy':
        criterion = LabelSmoothingCrossEntropy(reduction='sum').to(opt.device)
    else:
        criterion = torch.nn.CrossEntropyLoss().to(opt.device)

    # train one epoch
    total_loss = 0
    avg_loss = 0
    local_best_eval_loss = float('inf')
    local_best_eval_acc = 0
    total_examples = 0
    st_time = time.time()
    optimizer.zero_grad()
    epoch_iterator = tqdm(train_loader,
                          total=len(train_loader),
                          desc=f"Epoch {epoch_i}")
    # NOTE(review): local_step / global_step / curr_lr are bound inside
    # this loop; an empty train_loader would leave them undefined at the
    # end-of-epoch logging below — confirm loaders are always non-empty.
    for local_step, (x, y) in enumerate(epoch_iterator):
        model.train()
        global_step = (len(train_loader) * epoch_i) + local_step
        x = to_device(x, opt.device)
        y = to_device(y, opt.device)
        with autocast(enabled=opt.use_amp):
            if opt.use_profiler:
                with profiler.profile(profile_memory=True,
                                      record_shapes=True) as prof:
                    output = model(x)
                print(prof.key_averages().table(
                    sort_by="self_cpu_memory_usage", row_limit=10))
            else:
                output = model(x)
            # KLDivLoss expects log-probabilities as input
            if opt.criterion == 'KLDivLoss':
                loss = criterion(F.log_softmax(output, dim=1), y)
            else:
                loss = criterion(output, y)
            # scale loss so accumulated gradients average correctly
            if opt.gradient_accumulation_steps > 1:
                loss = loss / opt.gradient_accumulation_steps
        # back-propagation - begin
        # on CPU the GradScaler path is skipped (AMP is GPU-only here)
        if opt.device == 'cpu':
            loss.backward()
        else:
            scaler.scale(loss).backward()
        if (local_step + 1) % opt.gradient_accumulation_steps == 0:
            if opt.device == 'cpu':
                optimizer.step()
            else:
                # unscale before clipping so the norm is in true units
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               opt.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            curr_lr = scheduler.get_last_lr(
            )[0] if scheduler else optimizer.param_groups[0]['lr']
            epoch_iterator.set_description(
                f"Epoch {epoch_i}, local_step: {local_step}, loss: {loss:.3f}, curr_lr: {curr_lr:.7f}"
            )
            # periodic mid-epoch evaluation / checkpointing
            if opt.eval_and_save_steps > 0 and global_step != 0 and global_step % opt.eval_and_save_steps == 0:
                # evaluate
                eval_loss, eval_acc = evaluate(model, config, valid_loader)
                if local_best_eval_loss > eval_loss:
                    local_best_eval_loss = eval_loss
                if local_best_eval_acc < eval_acc:
                    local_best_eval_acc = eval_acc
                if writer:
                    writer.add_scalar('Loss/valid', eval_loss, global_step)
                    writer.add_scalar('Acc/valid', eval_acc, global_step)
                    writer.add_scalar('LearningRate/train', curr_lr,
                                      global_step)
                # 'loss' is minimized; any other measure (accuracy) is maximized
                if opt.measure == 'loss': eval_measure = eval_loss
                else: eval_measure = eval_acc
                if opt.measure == 'loss':
                    is_best = eval_measure < best_eval_measure
                else:
                    is_best = eval_measure > best_eval_measure
                if is_best:
                    best_eval_measure = eval_measure
                    # don't checkpoint during hyper-parameter search runs
                    if opt.save_path and not opt.hp_search_optuna and not opt.hp_search_nni:
                        logger.info("[Best model saved] : {}, {}".format(
                            eval_loss, eval_acc))
                        save_model(config, model, valid_loader=valid_loader)
                        # save finetuned bert model/config/tokenizer
                        if config['emb_class'] not in ['glove']:
                            if not os.path.exists(opt.bert_output_dir):
                                os.makedirs(opt.bert_output_dir)
                            model.bert_tokenizer.save_pretrained(
                                opt.bert_output_dir)
                            model.bert_model.save_pretrained(
                                opt.bert_output_dir)
        # back-propagation - end
        cur_examples = y.size(0)
        total_examples += cur_examples
        # weight per-batch loss by batch size for a correct epoch average
        total_loss += (loss.item() * cur_examples)
        if writer: writer.add_scalar('Loss/train', loss.item(), global_step)
    avg_loss = total_loss / total_examples

    # evaluate at the end of epoch
    eval_loss, eval_acc = evaluate(model, config, valid_loader)
    if local_best_eval_loss > eval_loss: local_best_eval_loss = eval_loss
    if local_best_eval_acc < eval_acc: local_best_eval_acc = eval_acc
    if writer:
        writer.add_scalar('Loss/valid', eval_loss, global_step)
        writer.add_scalar('Acc/valid', eval_acc, global_step)
        writer.add_scalar('LearningRate/train', curr_lr, global_step)
    if opt.measure == 'loss': eval_measure = eval_loss
    else: eval_measure = eval_acc
    if opt.measure == 'loss': is_best = eval_measure < best_eval_measure
    else: is_best = eval_measure > best_eval_measure
    if is_best:
        best_eval_measure = eval_measure
        if opt.save_path and not opt.hp_search_optuna and not opt.hp_search_nni:
            logger.info("[Best model saved] : {}, {}".format(
                eval_loss, eval_acc))
            save_model(config, model, valid_loader=valid_loader)
            # save finetuned bert model/config/tokenizer
            if config['emb_class'] not in ['glove']:
                if not os.path.exists(opt.bert_output_dir):
                    os.makedirs(opt.bert_output_dir)
                model.bert_tokenizer.save_pretrained(opt.bert_output_dir)
                model.bert_model.save_pretrained(opt.bert_output_dir)

    curr_time = time.time()
    elapsed_time = (curr_time - st_time) / 60  # minutes
    st_time = curr_time
    logs = {
        'epoch': epoch_i,
        'local_step': local_step + 1,
        'epoch_step': len(train_loader),
        'avg_loss': avg_loss,
        'local_best_eval_loss': local_best_eval_loss,
        'local_best_eval_acc': local_best_eval_acc,
        'best_eval_measure': best_eval_measure,
        'elapsed_time': elapsed_time
    }
    logger.info(json.dumps(logs, indent=4, ensure_ascii=False, sort_keys=True))

    return local_best_eval_loss, local_best_eval_acc, best_eval_measure