Example #1
                    # compute the accuracy
                    acc_t = compute_accuracy(y_pred,
                                             batch_dict['y_target'],
                                             output_type=args.output_type)
                    running_acc += (acc_t - running_acc) / (batch_index + 1)

                    val_bar.set_postfix(loss=running_loss,
                                        acc=running_acc,
                                        epoch=epoch_index)
                    val_bar.update()

                train_state['val_loss'].append(running_loss)
                train_state['val_acc'].append(running_acc)

                train_state = update_train_state(args=args,
                                                 model=classifier,
                                                 train_state=train_state)

                scheduler.step(train_state['val_loss'][-1])

                train_bar.n = 0
                val_bar.n = 0
                epoch_bar.update()

                if train_state['stop_early']:
                    break

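The running_loss and running_acc values above are maintained with an incremental-mean update, running += (value - running) / (step + 1), which keeps a running average without storing every per-batch value. A minimal sketch of the pattern in isolation (the value list is made up purely for illustration):

values = [0.9, 0.7, 0.8, 0.4]   # per-batch metrics, illustrative only
running = 0.0
for step, value in enumerate(values):
    running += (value - running) / (step + 1)
print(running)                  # ~0.7, the same as sum(values) / len(values)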
Example #2
def main(args):
    set_envs(args)
    print("Using: {}".format(args.device))

    total_items = os.popen(f'ls -l {args.hotpotQA_item_folder} |grep "^-"| wc -l').readlines()[0].strip()
    total_items = int(total_items)
    print(f"total items: {total_items}")

    # instantiate the model
    classifier = GAT_HotpotQA(features=args.features, hidden=args.hidden, nclass=args.nclass, 
                                dropout=args.dropout, alpha=args.alpha, nheads=args.nheads, 
                                nodes_num=args.pad_max_num)

    class_weights_sent, class_weights_para, class_weights_Qtype = \
                HotpotQA_GNN_Dataset.get_weights(device=args.device)
    loss_func_sent = nn.CrossEntropyLoss(class_weights_sent,ignore_index=-100)
    loss_func_para = nn.CrossEntropyLoss(class_weights_para,ignore_index=-100)
    loss_func_Qtype = nn.CrossEntropyLoss(class_weights_Qtype)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, classifier.parameters()),
                        lr=args.learning_rate)

    # Initialization
    opt_level = 'O1'
    if args.cuda:
        classifier = classifier.to(args.device)
        if args.fp16:
            classifier, optimizer = amp.initialize(classifier, optimizer, opt_level=opt_level)
        classifier = nn.parallel.DistributedDataParallel(classifier,
                                                        device_ids=args.device_ids, 
                                                        output_device=0, 
                                                        find_unused_parameters=True)
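        # note: the model is moved to the device and passed through amp.initialize()
        # before the DistributedDataParallel wrapping, matching the ordering apex
        # recommends for mixed-precision training.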
        
    if args.reload_from_files:
        checkpoint = torch.load(args.model_state_file)
        classifier.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if args.fp16: amp.load_state_dict(checkpoint['amp'])

    train_state = make_train_state(args)

    try:
        writer = SummaryWriter(log_dir=args.log_dir, flush_secs=args.flush_secs)
        cursor_train = 0
        cursor_val = 0
        if args.reload_from_files and 'cursor_train' in checkpoint.keys():
            cursor_train = checkpoint['cursor_train'] + 1
            cursor_val = checkpoint['cursor_val'] + 1
        # for epoch_index in range(args.num_epochs):
        if args.chunk_size < 0:
            args.chunk_size = total_items
        for chunk_i in range(0, total_items, args.chunk_size):
            dataset = HotpotQA_GNN_Dataset.build_dataset(hotpotQA_item_folder = args.hotpotQA_item_folder,
                                                        i_from = chunk_i, 
                                                        i_to = chunk_i+args.chunk_size,
                                                        seed=args.seed+chunk_i)

            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer,
                            mode='min',
                            factor=0.7,
                            patience=dataset.get_num_batches(args.batch_size)/10)

            dataset.set_parameters(args.pad_max_num, args.pad_value)
            epoch_bar = tqdm(desc='training routine',
                            total=args.num_epochs,
                            position=0)

            dataset.set_split('train')
            train_bar = tqdm(desc='split=train',
                            total=dataset.get_num_batches(args.batch_size), 
                            position=1)

            dataset.set_split('val')
            val_bar = tqdm(desc='split=val',
                            total=dataset.get_num_batches(args.batch_size), 
                            position=1)

            for epoch_index in range(args.num_epochs):
                train_bar.n = 0
                val_bar.n = 0
                train_state['epoch_index'] = epoch_index
                dataset.set_split('train')
                batch_generator = gen_GNN_batches(dataset,
                                                batch_size=args.batch_size, 
                                                device=args.device)
                running_loss = 0.0
                running_acc_Qtype = 0.0
                running_acc_topN = 0.0

                classifier.train()
                for batch_index, batch_dict in enumerate(batch_generator):
                    optimizer.zero_grad()

                    logits_sent, logits_para, logits_Qtype = \
                                    classifier(batch_dict['feature_matrix'], batch_dict['adj'])

                    # topN sents
                    max_value, max_index = logits_sent.max(dim=-1) # max_index is predict class.
                    topN_sent_index_batch = (max_value * batch_dict['sent_mask'].squeeze()).topk(args.topN_sents, dim=-1)[1]
                    topN_sent_predict = torch.gather(max_index, -1, topN_sent_index_batch)
                    topN_sent_label = torch.gather((batch_dict['labels'] * batch_dict['sent_mask']).squeeze(),
                                                    -1, 
                                                    topN_sent_index_batch)

                    logits_sent = (batch_dict['sent_mask'] * logits_sent).view(-1,2)

                    labels_sent = (batch_dict['sent_mask']*batch_dict['labels'] + \
                                batch_dict['sent_mask'].eq(0)*-100).view(-1)

                    logits_para = (batch_dict['para_mask'] * logits_para).view(-1,2)

                    labels_para = (batch_dict['para_mask']*batch_dict['labels'] + \
                                batch_dict['para_mask'].eq(0)*-100).view(-1)

                    loss_sent = loss_func_sent(logits_sent, labels_sent) # [B,2] [B]
                    loss_para = loss_func_para(logits_para, labels_para) # [B,2] [B]
                    loss_Qtype = loss_func_Qtype(logits_Qtype.view(-1,2),
                                                batch_dict['answer_type'].view(-1)) # [B,2] [B]

                    loss = loss_sent + loss_para + loss_Qtype
                    running_loss += (loss.item() - running_loss) / (batch_index + 1)

                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    optimizer.step()
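                    # the plateau scheduler is stepped once per batch on the running
                    # training loss; its patience was set above to roughly 1/10 of
                    # the batches in the current chunk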
                    scheduler.step(running_loss)

                    # compute the recall
                    recall_t_sent = compute_recall(logits_sent.view(-1,2), 
                                                batch_dict['labels'].view(-1), 
                                                batch_dict['sent_mask'].view(-1))

                    recall_t_para = compute_recall(logits_para.view(-1,2), 
                                                batch_dict['labels'].view(-1), 
                                                batch_dict['para_mask'].view(-1))
                    # compute the acc
                    acc_t_Qtype = compute_accuracy(logits_Qtype.view(-1,2), 
                                                batch_dict['answer_type'].view(-1))
                    running_acc_Qtype += (acc_t_Qtype - running_acc_Qtype) / (batch_index + 1)

                    acc_t_topN = compute_accuracy(predict=topN_sent_predict.view(-1), 
                                                labels=topN_sent_label.view(-1))
                    running_acc_topN += (acc_t_topN - running_acc_topN) / (batch_index + 1)

                    # update bar
                    train_bar.set_postfix(loss=running_loss,epoch=epoch_index)
                    train_bar.update()
                    
                    writer.add_scalar('loss/train', loss.item(), cursor_train)
                    writer.add_scalar('recall_t_sent/train', recall_t_sent, cursor_train)
                    writer.add_scalar('recall_t_para/train', recall_t_para, cursor_train)

                    writer.add_scalar('running_acc_Qtype/train', running_acc_Qtype, cursor_train)
                    writer.add_scalar('running_acc_topN/train', running_acc_topN, cursor_train)
                    writer.add_scalar('running_loss/train', running_loss, cursor_train)

                    cursor_train += 1

                train_state['train_running_loss'].append(running_loss)

                # Iterate over val dataset
                # setup: batch generator, set loss and acc to 0; set eval mode on

                dataset.set_split('val')
                batch_generator = gen_GNN_batches(dataset,
                                                batch_size=args.batch_size, 
                                                device=args.device)
                running_loss = 0.0
                running_acc_Qtype = 0.0
                running_acc_topN = 0.0
                classifier.eval()

                for batch_index, batch_dict in enumerate(batch_generator):
                    # compute the output
                    with torch.no_grad():

                        logits_sent, logits_para, logits_Qtype = \
                                        classifier(batch_dict['feature_matrix'], batch_dict['adj'])
                        
                        max_value, max_index = logits_sent.max(dim=-1) # max_index is predict class.
                        topN_sent_index_batch = (max_value * batch_dict['sent_mask'].squeeze()).topk(args.topN_sents, dim=-1)[1]
                        topN_sent_predict = torch.gather(max_index, -1, topN_sent_index_batch)

                        topN_sent_label = torch.gather((batch_dict['labels'] * batch_dict['sent_mask']).squeeze(),
                                                        -1, 
                                                        topN_sent_index_batch)

                        logits_sent = (batch_dict['sent_mask'] * logits_sent).view(-1,2)

                        labels_sent = (batch_dict['sent_mask']*batch_dict['labels'] + \
                                    batch_dict['sent_mask'].eq(0)*-100).view(-1)

                        logits_para = (batch_dict['para_mask'] * logits_para).view(-1,2)

                        labels_para = (batch_dict['para_mask']*batch_dict['labels'] + \
                                    batch_dict['para_mask'].eq(0)*-100).view(-1)

                        loss_sent = loss_func_sent(logits_sent, labels_sent) # [B,2] [B]
                        loss_para = loss_func_para(logits_para, labels_para) # [B,2] [B]
                        loss_Qtype = loss_func_Qtype(logits_Qtype.view(-1,2),
                                                    batch_dict['answer_type'].view(-1)) # [B,2] [B]

                        loss = loss_sent + loss_para + loss_Qtype

                        try:
                            running_loss += (loss.item() - running_loss) / (batch_index + 1)
                        except RuntimeError:
                            print(f"err batch: {batch_index}")
                            print_exc()
                            exit()

                    # compute the recall
                    recall_t_sent = compute_recall(logits_sent.view(-1,2), 
                                                batch_dict['labels'].view(-1), 
                                                batch_dict['sent_mask'].view(-1))

                    recall_t_para = compute_recall(logits_para.view(-1,2), 
                                                batch_dict['labels'].view(-1), 
                                                batch_dict['para_mask'].view(-1))
                    # compute the acc
                    acc_t_Qtype = compute_accuracy(logits_Qtype.view(-1,2), 
                                                batch_dict['answer_type'].view(-1))
                    running_acc_Qtype += (acc_t_Qtype - running_acc_Qtype) / (batch_index + 1)

                    acc_t_topN = compute_accuracy(predict=topN_sent_predict.view(-1), 
                                                labels=topN_sent_label.view(-1))
                    running_acc_topN += (acc_t_topN - running_acc_topN) / (batch_index + 1)


                    # update bar
                    val_bar.set_postfix(loss=running_loss,epoch=epoch_index)
                    val_bar.update()

                    writer.add_scalar('loss/val', loss.item(), cursor_val)
                    writer.add_scalar('recall_t_sent/val', recall_t_sent, cursor_val)
                    writer.add_scalar('recall_t_para/val', recall_t_para, cursor_val)

                    writer.add_scalar('running_acc_Qtype/val', running_acc_Qtype, cursor_val)
                    writer.add_scalar('running_acc_topN/val', running_acc_topN, cursor_val)
                    writer.add_scalar('running_loss/val', running_loss, cursor_val)
                    cursor_val += 1


                train_state['val_running_loss'].append(running_loss)

                train_state = update_train_state(args=args, model=classifier, 
                                                optimizer = optimizer,
                                                train_state=train_state)
                epoch_bar.update()

                # if train_state['stop_early']:
                #     print('STOP EARLY!')
                #     exit()

                # epoch done.
            # chunk done.
        # all finished.
    except KeyboardInterrupt:
        print("Exiting loop")
    except Exception:
        print(f"err in chunk {chunk_i}, epoch_index {epoch_index}, batch_index {batch_index}.")
        print_exc()
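In the loops above, padded node positions are excluded from the sentence and paragraph losses by writing -100 into the masked label slots and building the loss functions with ignore_index=-100. A minimal, self-contained sketch of that masking pattern (tensor shapes and values are made up for illustration):

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

logits = torch.randn(1, 4, 2)              # [B, nodes, 2]
labels = torch.tensor([[1, 0, 1, 0]])      # [B, nodes]
mask = torch.tensor([[1, 1, 0, 0]])        # 1 = real node, 0 = padding

# padded positions receive label -100, so CrossEntropyLoss ignores them
masked_labels = (mask * labels + mask.eq(0) * -100).view(-1)
loss = loss_fn((mask.unsqueeze(-1) * logits).view(-1, 2), masked_labels)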
def train(model, train_dataset, val_dataset, args: Namespace):
    """
    :param model: model pytorch
    :param train_dataset:
    :param val_dataset:
    :param args:
    :return:
    """

    checkpoint_path = args.model_dir + '/best_model.pth'
    if args.checkpoint and os.path.exists(checkpoint_path):
        print("load model from " + checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))

    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("device: ", device)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.01,
        steps_per_epoch=len(train_loader),
        epochs=args.num_epochs)

    train_state = make_train_state(args)

    for epoch in range(args.num_epochs):
        start_time = time.time()
        print(32 * '_' + ' ' * 9 + 31 * '_')
        print(32 * '_' + 'EPOCH {:3}'.format(epoch + 1) + 31 * '_')
        train_state['epoch_index'] = epoch
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        train_loss = 0.
        model.train()
        for i, (x_vector, y_vector, x_mask) in enumerate(train_loader):
            optimizer.zero_grad()
            out = model(x_vector.to(device), x_mask.to(device))
            loss = model.compute_loss(out, y_vector.to(device),
                                      x_mask.to(device))
            loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            train_loss += (loss.item() - train_loss) / (i + 1)
            if (i + 1) % 10 == 0:
                print("|{:4}| train_loss = {:.4f}".format(i + 1, loss.item()))

        # writer.add_scalar('Train/Loss', train_loss, epoch)
        train_state['train_loss'].append(train_loss)
        optimizer.zero_grad()
        val_loss, f1, acc = model.evaluate(val_dataset,
                                           thresh=args.thresh,
                                           batch_size=args.batch_size)
        # scheduler.step(val_loss)
        print('Val loss: {:.4f}'.format(val_loss))
        scheduler.step()

        # track negative F1 as the validation "loss" so that a rising F1 counts as improvement
        train_state['val_loss'].append(-f1)

        ## update train state
        train_state = update_train_state(model, train_state)

        if train_state['stop_early']:
            print('Stop early.......!')
            break

    return model
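make_train_state and update_train_state are not shown in these examples. A minimal sketch of what such helpers might track to drive the early stopping used above (an assumption about their shape, not the implementation these projects actually use):

def make_train_state(args):
    return {'epoch_index': 0,
            'early_stopping_step': 0,
            'stop_early': False,
            'train_loss': [],
            'val_loss': [],
            'val_acc': []}

def update_train_state(model, train_state, patience=3):
    # count epochs without improvement in the tracked validation loss
    if len(train_state['val_loss']) > 1 and \
            train_state['val_loss'][-1] >= min(train_state['val_loss'][:-1]):
        train_state['early_stopping_step'] += 1
    else:
        train_state['early_stopping_step'] = 0
    train_state['stop_early'] = train_state['early_stopping_step'] >= patience
    return train_state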
def train(model, train_dataset, val_dataset, args: Namespace):
    """
    :param model: model pytorch
    :param train_dataset:
    :param val_dataset:
    :param args:
    :return:
    """

    checkpoint_path = args.model_dir + '/last_checkpoint.pth'
    if args.checkpoint and os.path.exists(checkpoint_path):
        print("load model from " + checkpoint_path)
        model.load_state_dict(torch.load(checkpoint_path))

    if not os.path.exists(args.model_dir):
        os.mkdir(args.model_dir)

    device = torch.device(
        "cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("device: ", device)

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=args.batch_size,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            drop_last=True)

    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     mode='min',
                                                     factor=0.5,
                                                     patience=1)
    #scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(train_loader), epochs=args.num_epochs)

    train_state = make_train_state(args)

    for epoch in range(args.num_epochs):
        start_time = time.time()
        print("Epoch: ", epoch + 1)
        train_state['epoch_index'] = epoch
        train_loader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        train_loss = 0.
        model.train()
        for i, (x_vector, y_vector, x_mask) in enumerate(train_loader):
            optimizer.zero_grad()
            loss, out = model.forward_loss(x_vector.to(device),
                                           y_vector.to(device),
                                           x_mask.to(device))
            loss.backward()

            # gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

            optimizer.step()
            train_loss += (loss.item() - train_loss) / (i + 1)
            if (i + 1) % 10 == 0:
                print("\tStep: {} train_loss: {}".format(i + 1, loss.item()))

        # writer.add_scalar('Train/Loss', train_loss, epoch)
        train_state['train_loss'].append(train_loss)

        model.eval()
        val_loss = 0.
        y_pred = []
        y_true = []

        with torch.no_grad():
            for i, (x_vector, y_vector, x_mask) in enumerate(val_loader):
                loss, out = model.forward_loss(x_vector.to(device),
                                               y_vector.to(device),
                                               x_mask.to(device))
                val_loss += (loss.item() - val_loss) / (i + 1)
                y_pred.append((out > args.thresh).long())
                y_true.append(y_vector.long())

        # step the plateau scheduler and update the train state once per epoch
        scheduler.step(val_loss)
        train_state['val_loss'].append(val_loss)
        train_state = update_train_state(model, train_state)

        y_true = torch.cat(y_true, dim=-1).cpu().detach().numpy()
        y_pred = torch.cat(y_pred, dim=-1).cpu().detach().numpy()

        acc, sub_acc, f1, precision, recall, hamming_loss = get_multi_label_metrics(
            y_true=y_true, y_pred=y_pred)
        print('f1: {}    precision: {}    recall: {}'.format(
            f1, precision, recall))
        print('accuracy: {}    sub accuracy : {}    hamming loss: {}'.format(
            acc, sub_acc, hamming_loss))
        # save the current model state after every epoch
        torch.save(model.state_dict(), args.model_dir + '/' + args.model_name)
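get_multi_label_metrics is not defined here either. One plausible implementation of the metrics it is shown returning, built on scikit-learn (an assumption about the helper, not its actual source):

from sklearn import metrics

def get_multi_label_metrics(y_true, y_pred):
    # y_true, y_pred: binary arrays with one column per label
    acc = metrics.accuracy_score(y_true.reshape(-1), y_pred.reshape(-1))  # element-wise accuracy
    sub_acc = metrics.accuracy_score(y_true, y_pred)                      # exact-match (subset) accuracy
    f1 = metrics.f1_score(y_true, y_pred, average='micro')
    precision = metrics.precision_score(y_true, y_pred, average='micro')
    recall = metrics.recall_score(y_true, y_pred, average='micro')
    hamming = metrics.hamming_loss(y_true, y_pred)
    return acc, sub_acc, f1, precision, recall, hamming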