Example #1
# Imports assumed by the functions in this listing.
import copy

import numpy as np
import torch


def valid(x, y, batch_size, model):
    # batch_list, criterion, and acc are helpers defined elsewhere in the project.
    x_batches = batch_list(x, batch_size)
    y_batches = batch_list(y, batch_size)

    total_loss, iter_num, val_acc = 0, 0, 0
    model.eval()

    with torch.no_grad():
        for step, data in enumerate(zip(x_batches, y_batches)):
            inputs, labels = np.array(data[0]), torch.tensor(data[1])

            # Reshape to (batch, channel, height, width), e.g. torch.Size([128, 1, 100, 71]).
            inputs = inputs.reshape(inputs.shape[0], 1, 100, 71)
            inputs = torch.from_numpy(inputs).float()

            outputs = model(inputs)
            # Labels are one-hot; argmax recovers the class indices the loss expects.
            loss = criterion(outputs, torch.argmax(labels, dim=-1))

            # .item() detaches the scalar so no computation graph is retained.
            total_loss += loss.item()
            iter_num += 1

            v_acc = acc(torch.max(outputs, 1)[1], torch.argmax(labels, dim=-1))
            val_acc += v_acc

    return total_loss / iter_num, val_acc / iter_num
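
The batch_list and acc helpers are not defined in the example. A minimal sketch of what the two-argument versions used here might look like (the pad-aware, three-argument acc used in the later functions would differ):

# Hypothetical sketches of the helpers the example assumes; illustrative only.
def batch_list(data, batch_size):
    # Split a sequence into consecutive chunks of at most batch_size items.
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

def acc(pred, target):
    # Fraction of positions where the predicted class matches the target class.
    return (pred == target).float().mean().item()
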
def train(model, iterator, optimizer, criterion, args, bert_tok):
    total_loss = 0
    iter_num = 0
    train_acc = 0
    # iteration, total_train_loss, and iteration_list are module-level
    # logging variables shared with the plotting code.
    global iteration

    model.train()

    # refine_idx is used on both branches of the model call below, so the
    # keyword data is loaded unconditionally (the original only loaded it
    # when useKey was enabled, which left refine_idx undefined otherwise).
    keyword, refine_idx = keyword_loader(args, 'train', bert_tok)

    for step, batch in enumerate(iterator):
        optimizer.zero_grad()

        enc_inputs = batch.que

        # Deep-copy the answers so building decoder inputs and targets
        # does not mutate the batch in place.
        copy_dec_inputs = copy.deepcopy(batch.ans)
        copy_dec_target = copy.deepcopy(batch.ans)

        dec_inputs = get_dec_inputs(copy_dec_inputs, gpt_pad_token,
                                    gpt_eos_token)
        target_ = get_target(copy_dec_target, gpt_pad_token)
        target_ = target_.view(-1)

        segment_ids, valid_len = get_segment_ids_vaild_len(
            enc_inputs, pad_token_idx)
        attention_mask = gen_attention_mask(enc_inputs, valid_len)

        # Keywords are fed to the model only when enabled via --useKey.
        key = keyword[step] if args.useKey == 'True' else None
        outputs = model(enc_inputs, dec_inputs, segment_ids,
                        attention_mask, key, refine_idx[step])

        loss = criterion(outputs, target_)

        loss.backward()
        optimizer.step()

        # .item() detaches the scalar loss so the graph is not retained
        # across iterations.
        total_loss += loss.item()
        iter_num += 1
        with torch.no_grad():
            tr_acc = acc(outputs, target_, gpt_pad_token)
        train_acc += tr_acc.item()

        # Record the running mean loss every other step for plotting.
        if step % 2 == 0:
            total_train_loss.append(total_loss / iter_num)
            iteration_list.append(iteration)
            iteration += 1

    return total_loss / iter_num, train_acc / iter_num
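
A sketch of how train and valid might be driven from an outer training script; train_iter, valid_iter, and args.num_epochs are placeholder names, not names from the source:

# Hypothetical training driver; illustrative only.
for epoch in range(args.num_epochs):
    train_loss, train_acc_ = train(model, train_iter, optimizer,
                                   criterion, args, bert_tok)
    val_loss, val_acc_ = valid(model, valid_iter, optimizer,
                               criterion, args, bert_tok)
    print(f'epoch {epoch:02d} | train loss {train_loss:.4f} '
          f'acc {train_acc_:.4f} | val loss {val_loss:.4f} '
          f'acc {val_acc_:.4f}')
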
def valid(model, iterator, optimizer, criterion, args, bert_tok):
    total_loss = 0
    iter_num = 0
    test_acc = 0
    model.eval()

    # As in train(), refine_idx is needed regardless of --useKey, so the
    # keyword data is loaded unconditionally.
    keyword, refine_idx = keyword_loader(args, 'valid', bert_tok)

    with torch.no_grad():
        for step, batch in enumerate(iterator):
            enc_inputs = batch.que

            copy_dec_inputs = copy.deepcopy(batch.ans)
            copy_dec_target = copy.deepcopy(batch.ans)

            dec_inputs = get_dec_inputs(copy_dec_inputs, gpt_pad_token,
                                        gpt_eos_token)
            target_ = get_target(copy_dec_target, gpt_pad_token)
            target_ = target_.view(-1)

            segment_ids, valid_len = get_segment_ids_vaild_len(
                enc_inputs, pad_token_idx)
            attention_mask = gen_attention_mask(enc_inputs, valid_len)

            key = keyword[step] if args.useKey == 'True' else None
            outputs = model(enc_inputs, dec_inputs, segment_ids,
                            attention_mask, key, refine_idx[step])

            loss = criterion(outputs, target_)

            total_loss += loss.item()
            iter_num += 1
            te_acc = acc(outputs, target_, gpt_pad_token)

            # Log decoded predictions against references for inspection.
            test_time_visual(args, enc_inputs, outputs, target_,
                             bert_tokenizer, gpt_vocab)
            test_acc += te_acc.item()

    return total_loss / iter_num, test_acc / iter_num
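
Both loops build the encoder-side inputs with get_segment_ids_vaild_len and gen_attention_mask, whose implementations are not shown. A plausible reconstruction, assuming single-segment inputs where valid_len counts the non-pad tokens per sequence:

# Hypothetical reconstruction of the masking helpers; illustrative only.
def get_segment_ids_vaild_len(enc_inputs, pad_token_idx):
    # Single-sentence inputs: every position belongs to segment 0.
    segment_ids = torch.zeros_like(enc_inputs)
    # Number of real (non-pad) tokens in each sequence.
    valid_len = (enc_inputs != pad_token_idx).sum(dim=-1)
    return segment_ids, valid_len

def gen_attention_mask(enc_inputs, valid_len):
    # 1 for the first valid_len positions of each row, 0 for padding.
    mask = torch.zeros_like(enc_inputs)
    for i, l in enumerate(valid_len):
        mask[i, :int(l)] = 1
    return mask
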