Code Example #1
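All six examples assume roughly the same preamble. The sketch below lists the imports the snippets actually use; the model classes (EmbeddingE2EModeler, EmbeddingContextModeler, E2ECNNModeler, GCNContext) and the helpers (valid_util, stg, build_graph) are project-local modules that are not shown here, so their exact locations are assumptions.

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch import autograd
from torch.autograd import Variable  # deprecated since PyTorch 0.4; kept because the snippets use it
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm

# Shared device handle used by examples 2, 3, and 6.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')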
def train(dm_train_set, dm_test_set):

    EMBEDDING_DIM = 200
    batch_size = 128
    epoch_num = 100
    max_acc = 0
    model_save_path = './tmp/model_save/triplet_embed.model'

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         drop_last=False,
                                         num_workers=8)

    model = EmbeddingE2EModeler(dm_train_set.vocab_size(), EMBEDDING_DIM)
    print(model)
    init_weight = np.loadtxt("./tmp/24581_we_weights.txt")
    model.init_emb(init_weight)
    if torch.cuda.is_available():
        print("CUDA : On")
        model.cuda()
    else:
        print("CUDA : Off")

    embedding_params = list(map(id, model.embedding.parameters()))
    other_params = filter(lambda p: id(p) not in embedding_params,
                          model.parameters())

    optimizer = optim.Adam([{
        'params': other_params
    }, {
        'params': model.embedding.parameters(),
        'lr': 1e-4
    }],
                           lr=1e-3,
                           betas=(0.9, 0.99))

    logging = False
    if logging:
        writer = SummaryWriter()
        log_name = 'Triplet_embed'

    history = None

    for epoch in range(epoch_num):
        for batch_idx, sample_dict in enumerate(dm_dataloader):
            anchor = Variable(torch.LongTensor(sample_dict['anchor']))
            pos = Variable(torch.LongTensor(sample_dict['pos']))
            neg = Variable(torch.LongTensor(sample_dict['neg']))
            label = Variable(torch.LongTensor(sample_dict['label']))
            mask = Variable(torch.LongTensor(sample_dict['mask']))
            mask_ = mask.type(torch.FloatTensor).view(-1)
            if torch.cuda.is_available():
                anchor = anchor.cuda()
                pos = pos.cuda()
                neg = neg.cuda()
                label = label.cuda()
                mask = mask.cuda()
                mask_ = mask_.cuda()

            optimizer.zero_grad()
            anchor_embed = model.embed(anchor)
            pos_embed = model.embed(pos)
            neg_embed = model.embed(neg)
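            # Metric-learning term: pull the anchor embedding toward the positive
            # and push it away from the negative by at least the margin.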
            triplet_loss = nn.TripletMarginLoss(margin=10, p=2)
            embedding_loss = triplet_loss(anchor_embed, pos_embed, neg_embed)
            anchor_pred = model.forward(anchor).unsqueeze(1)
            pos_pred = model.forward(pos).unsqueeze(1)
            neg_pred = model.forward(neg).unsqueeze(1)
            final_pred = torch.cat((anchor_pred, pos_pred, neg_pred), dim=1)
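            # (batch, 3, 2) -> (3 * batch, 2): score anchor, pos and neg as
            # independent binary classifications.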
            final_pred = final_pred.view(1, -1, 2)
            final_pred = final_pred.squeeze()

            cross_entropy = nn.NLLLoss(reduction='none')
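            # Zero out the labels of unsupervised samples (mask == 0); mask_
            # removes their contribution from the loss below.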
            label = label.mul(mask)
            label = label.view(-1)
            classify_loss = cross_entropy(F.log_softmax(final_pred, dim=1),
                                          label)
            classify_loss = classify_loss.mul(mask_)
            if mask_.sum() > 0:
                classify_loss = classify_loss.sum() / mask_.sum()
            else:
                classify_loss = classify_loss.sum()

            alpha = stg.dynamic_alpha(embedding_loss, classify_loss)
            loss = alpha * embedding_loss + (1 - alpha) * classify_loss

            if batch_idx % 100 == 0:
                accuracy = valid_util.running_accuracy(final_pred, label,
                                                       mask_)
                print(
                    'epoch: %d batch %d : loss: %4.6f embed-loss: %4.6f class-loss: %4.6f accuracy: %4.6f'
                    % (epoch, batch_idx, loss.item(), embedding_loss.item(),
                       classify_loss.item(), accuracy))
                if logging:
                    writer.add_scalars(
                        log_name + '_data/loss', {
                            'Total Loss': loss,
                            'Embedding Loss': embedding_loss,
                            'Classify Loss': classify_loss
                        }, epoch * 10 + batch_idx // 100)
            loss.backward()
            optimizer.step()

        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report')
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalar(log_name + '_data/accuracy',
                              result_dict['accuracy'], epoch)

        accuracy, history = valid_util.validate(model,
                                                dm_test_set,
                                                dm_test_dataloader,
                                                mode='detail',
                                                pred_history=history)
        # pickle.dump(history, open('./tmp/e2e_we_history.pkl', 'wb'))
        if accuracy > max_acc:
            max_acc = accuracy
            # torch.save(model.state_dict(), model_save_path)

        # dm_valid_set = pickle.load(open('./tmp/triplet_valid_dataset.pkl', 'rb'))
        # valid_util.validate(model, dm_valid_set, mode='output')

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    return
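stg.dynamic_alpha is not shown in these snippets. Below is a minimal sketch of what such a weighting could look like, assuming it balances the two loss terms by their current magnitudes; the project's real implementation may differ.

def dynamic_alpha(embedding_loss, classify_loss):
    # Hypothetical: weight each term by the other's magnitude, so that
    # alpha * embedding_loss and (1 - alpha) * classify_loss are comparable.
    # detach() keeps the weighting itself out of the gradient computation.
    total = embedding_loss.detach() + classify_loss.detach()
    if total.item() == 0:
        return 0.5
    return (classify_loss.detach() / total).item()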
Code Example #2
def train(season_id, dm_train_set, dm_test_set, features, edges):

    EMBEDDING_DIM = 200
    batch_size = 128
    epoch_num = 300
    max_acc = 0
    max_v_acc = 0
    model_save_path = './tmp/model_save/gcn_context.model'

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=8)

    model = GCNContext(EMBEDDING_DIM, 256, 200, dropout=0.5)
    print(model)
    model.to(device)

    if torch.cuda.is_available():
        print("CUDA : On")
    else:
        print("CUDA : Off")

    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    # scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

    logging = True
    if logging:
        writer = SummaryWriter()
        log_name = 'gcn_context'

    graph = build_graph(features, edges)
    features = torch.FloatTensor(features)
    features = features.to(device)

    for epoch in tqdm(range(epoch_num)):
        print()
        model.train(mode=True)
        # scheduler.step()
        for batch_idx, sample_dict in enumerate(dm_dataloader):
            sentence = torch.LongTensor(sample_dict['sentence'])
            context = torch.LongTensor(sample_dict['context'])
            label = torch.LongTensor(sample_dict['label'])

            sentence = sentence.to(device)
            context = context.to(device)
            label = label.to(device)

            optimizer.zero_grad()
            pred = model(sentence, context, graph, features)
            cross_entropy = nn.CrossEntropyLoss()
            loss = cross_entropy(pred, label)
            if batch_idx % 10 == 0:
                accuracy = valid_util.running_accuracy(pred, label)
                print('epoch: %d batch %d : loss: %4.6f accuracy: %4.6f' %
                      (epoch, batch_idx, loss.item(), accuracy))
                if logging:
                    writer.add_scalar(log_name + '_data/loss', loss.item(),
                                      epoch * 10 + batch_idx // 10)
            loss.backward()
            optimizer.step()

        model.eval()
        accuracy = valid_util.validate(model,
                                       dm_test_set,
                                       dm_test_dataloader,
                                       mode='output',
                                       type='graph_context',
                                       features=features,
                                       g=graph)
        if accuracy > max_acc:
            max_acc = accuracy

        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report',
                                              type='graph_context',
                                              features=features,
                                              g=graph)
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalars(log_name + '_data/accuracy', {
                'accuracy': result_dict['accuracy'],
                'max_accuracy': max_acc
            }, epoch)

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    return
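valid_util.running_accuracy is project-local as well. For the two-argument call used in this example, a minimal sketch, assuming pred holds per-class logits (the project's actual implementation may differ):

def running_accuracy(pred, label):
    # pred: (batch, num_classes) logits; label: (batch,) class indices.
    correct = (pred.argmax(dim=1) == label).float()
    return correct.mean().item()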
Code Example #3
def train(season_id, dm_train_set, dm_test_set):

    EMBEDDING_DIM = 200
    batch_size = 128
    epoch_num = 100
    max_acc = 0
    max_v_acc = 0
    model_save_path = './tmp/model_save/straight_embed_context.model'

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=8)

    model = EmbeddingContextModeler(dm_train_set.vocab_size(), EMBEDDING_DIM,
                                    dm_train_set.context_words)
    print(model)
    # init_weight = np.loadtxt(os.path.join('./tmp', season_id, 'unigram_weights.txt'))
    # model.init_emb(init_weight)
    model.to(device)

    if torch.cuda.is_available():
        print("CUDA : On")
    else:
        print("CUDA : Off")
    optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.99))
    # scheduler = StepLR(optimizer, step_size=10, gamma=0.8)

    logging = False
    if logging:
        log_name = 'straight_embed_context'
        writer = SummaryWriter()

    for epoch in range(epoch_num):
        model.train(mode=True)
        # scheduler.step()
        for batch_idx, sample_dict in enumerate(dm_dataloader):
            sentence = torch.LongTensor(sample_dict['sentence'])
            label = torch.LongTensor(sample_dict['label'])
            context = torch.LongTensor(sample_dict['context'])
            context_count = sample_dict['context_count'].float()

            sentence = sentence.to(device)
            label = label.to(device)
            context = context.to(device)
            context_count = context_count.to(device)

            optimizer.zero_grad()
            pred = model.forward(sentence, context)
            cross_entropy = nn.CrossEntropyLoss()
            loss = cross_entropy(pred, label)
            if batch_idx % 10 == 0:
                accuracy = valid_util.running_accuracy(pred, label)
                print('epoch: %d batch %d : loss: %4.6f accuracy: %4.6f' %
                      (epoch, batch_idx, loss.item(), accuracy))
                if logging:
                    writer.add_scalar(log_name + '_data/loss', loss.item(),
                                      epoch * 10 + batch_idx // 10)
            loss.backward()
            optimizer.step()

        model.eval()
        accuracy = valid_util.validate(model,
                                       dm_test_set,
                                       dm_test_dataloader,
                                       mode='output',
                                       type='context')
        if accuracy > max_acc:
            max_acc = accuracy

        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report',
                                              type='context')
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalars(log_name + '_data/accuracy', {
                'accuracy': result_dict['accuracy'],
                'max_accuracy': max_acc
            }, epoch)

        # dm_valid_set = pickle.load(open(os.path.join('./tmp', season_id, 'unigram_context_valid_dataset.pkl'), 'rb'))
        # v_acc = valid_util.validate(model, dm_valid_set, mode='output', type='context')
        # if v_acc > max_v_acc:
        #     max_v_acc = v_acc

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    # print("Max Validation Accuracy: %4.6f" % max_v_acc)
    return
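The per-class precision/recall/F1 logging block recurs verbatim in every example. A small helper (hypothetical; not part of the original code) would keep it in one place, reading the same dictionary keys the snippets use:

def log_prf(writer, log_name, result_dict, epoch):
    # result_dict comes from valid_util.validate(..., mode='report').
    for cls in ('0', '1'):
        writer.add_scalars(
            log_name + '_data/%s-PRF' % cls, {
                '%s-Precision' % cls: result_dict[cls]['precision'],
                '%s-Recall' % cls: result_dict[cls]['recall'],
                '%s-F1-score' % cls: result_dict[cls]['f1-score']
            }, epoch)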
Code Example #4
def train(season_id, dm_train_set, dm_test_set):

    EMBEDDING_DIM = 200
    feature_dim = 50
    max_len = dm_train_set.max_len
    windows_size = [1, 2, 3, 4]
    batch_size = 256
    epoch_num = 50
    fusion_type = 'concat'
    max_acc = 0
    max_v_acc = 0
    model_save_path = './tmp/model_save/pycnn_' + fusion_type + '.model'

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=8)

    model = E2ECNNModeler(dm_train_set.vocab_size(),
                          dm_train_set.py_vocab_size(), EMBEDDING_DIM,
                          feature_dim, windows_size, max_len, fusion_type)
    print(model)
    init_weight = np.loadtxt(os.path.join('./tmp', season_id,
                                          'we_weights.txt'))
    model.init_emb(init_weight)
    init_weight = np.loadtxt(os.path.join('./tmp', season_id,
                                          'py_weights.txt'))
    model.init_py_emb(init_weight)
    if torch.cuda.is_available():
        print("CUDA : On")
        model.cuda()
    else:
        print("CUDA : Off")

    embedding_params = list(map(id, model.dynamic_embedding.parameters()))
    embedding_params.extend(list(map(id, model.py_embedding.parameters())))
    other_params = filter(lambda p: id(p) not in embedding_params,
                          model.parameters())

    optimizer = optim.Adam([{
        'params': other_params
    }, {
        'params': model.dynamic_embedding.parameters(),
        'lr': 1e-3
    }, {
        'params': model.py_embedding.parameters(),
        'lr': 1e-3
    }],
                           lr=1e-3,
                           betas=(0.9, 0.99),
                           weight_decay=1e-5)

    logging = True
    if logging:
        writer = SummaryWriter()
        log_name = 'pycnn_' + fusion_type

    history = None

    for epoch in range(epoch_num):

        model.train(mode=True)
        if (epoch + 1) % 5 == 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.5

        for batch_idx, sample_dict in enumerate(dm_dataloader):
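            # detect_anomaly() raises an error at the operation that produced a
            # NaN/Inf gradient; useful for debugging, but it slows training.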
            with autograd.detect_anomaly():
                anchor = Variable(torch.LongTensor(sample_dict['anchor']))
                pos = Variable(torch.LongTensor(sample_dict['pos']))
                neg = Variable(torch.LongTensor(sample_dict['neg']))
                py_anchor = Variable(torch.LongTensor(
                    sample_dict['py_anchor']))
                py_pos = Variable(torch.LongTensor(sample_dict['py_pos']))
                py_neg = Variable(torch.LongTensor(sample_dict['py_neg']))
                label = Variable(torch.LongTensor(sample_dict['label']))
                mask = Variable(torch.LongTensor(sample_dict['mask']))
                mask_ = mask.type(torch.FloatTensor).view(-1)
                if torch.cuda.is_available():
                    anchor = anchor.cuda()
                    pos = pos.cuda()
                    neg = neg.cuda()
                    py_anchor = py_anchor.cuda()
                    py_pos = py_pos.cuda()
                    py_neg = py_neg.cuda()
                    label = label.cuda()
                    mask = mask.cuda()
                    mask_ = mask_.cuda()

                optimizer.zero_grad()
                anchor_embed = model.embed(anchor, py_anchor)
                pos_embed = model.embed(pos, py_pos)
                neg_embed = model.embed(neg, py_neg)
                triplet_loss = nn.TripletMarginLoss(margin=10, p=2)
                embedding_loss = triplet_loss(anchor_embed, pos_embed,
                                              neg_embed)
                anchor_pred = model.forward(anchor, py_anchor).unsqueeze(1)
                pos_pred = model.forward(pos, py_pos).unsqueeze(1)
                neg_pred = model.forward(neg, py_neg).unsqueeze(1)
                final_pred = torch.cat((anchor_pred, pos_pred, neg_pred),
                                       dim=1)
                final_pred = final_pred.view(1, -1, 2)
                final_pred = final_pred.squeeze()

                cross_entropy = nn.NLLLoss(reduction='none')
                label = label.mul(mask)
                label = label.view(-1)
                classify_loss = cross_entropy(F.log_softmax(final_pred, dim=1),
                                              label)
                classify_loss = classify_loss.mul(mask_)

                if mask_.sum() > 0:
                    classify_loss = classify_loss.sum() / mask_.sum()
                else:
                    classify_loss = classify_loss.sum()

                alpha = stg.dynamic_alpha(embedding_loss, classify_loss)
                loss = alpha * embedding_loss + (1 - alpha) * classify_loss

                if batch_idx % 100 == 0:
                    accuracy = valid_util.running_accuracy(
                        final_pred, label, mask_)
                    print(
                        'epoch: %d batch %d : loss: %4.6f embed-loss: %4.6f class-loss: %4.6f accuracy: %4.6f num: %4.1f'
                        %
                        (epoch, batch_idx, loss.item(), embedding_loss.item(),
                         classify_loss.item(), accuracy, mask_.sum()))
                    if logging:
                        writer.add_scalars(
                            log_name + '_data/loss', {
                                'Total Loss': loss,
                                'Embedding Loss': embedding_loss,
                                'Classify Loss': classify_loss
                            }, epoch * 10 + batch_idx // 100)
                loss.backward()
                optimizer.step()

        model.eval()
        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report',
                                              py=True)
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalar(log_name + '_data/accuracy',
                              result_dict['accuracy'], epoch)
        accuracy, history = valid_util.validate(model,
                                                dm_test_set,
                                                dm_test_dataloader,
                                                mode='detail',
                                                py=True,
                                                pred_history=history)
        # pickle.dump(history, open('./tmp/e2e_pycnn_history.pkl', 'wb'))
        if accuracy > max_acc:
            max_acc = accuracy
            torch.save(model.state_dict(), model_save_path)

        # dm_valid_set = pickle.load(open('./tmp/triplet_valid_dataset.pkl', 'rb'))
        # v_acc = valid_util.validate(model, dm_valid_set, mode='output', py=True)
        # if v_acc > max_v_acc:
        #     max_v_acc = v_acc

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    print("Max Validation Accuracy: %4.6f" % max_v_acc)
    return
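The manual decay at the top of the epoch loop (halving every learning rate every 5 epochs) can also be expressed with a built-in scheduler; a rough equivalent, with scheduler.step() called once at the end of each epoch:

from torch.optim.lr_scheduler import StepLR
scheduler = StepLR(optimizer, step_size=5, gamma=0.5)  # halves every param group's lr every 5 epochs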
Code Example #5
def train(season_id, dm_train_set, dm_test_set):

    EMBEDDING_DIM = 200
    feature_dim = 50
    max_len = 49
    windows_size = [1, 2, 3, 4]
    batch_size = 128
    epoch_num = 100
    max_acc = 0
    max_v_acc = 0
    model_save_path = './tmp/model_save/straight_CNN.model'

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=8)

    model = E2ECNNModeler(dm_train_set.vocab_size(), EMBEDDING_DIM,
                          feature_dim, windows_size, max_len)
    print(model)
    init_weight = np.loadtxt(
        os.path.join('./tmp', season_id, 'unigram_weights.txt'))
    model.init_emb(init_weight)
    if torch.cuda.is_available():
        print("CUDA : On")
        model.cuda()
    else:
        print("CUDA : Off")

    embedding_params = list(map(id, model.dynamic_embedding.parameters()))
    other_params = filter(lambda p: id(p) not in embedding_params,
                          model.parameters())

    optimizer = optim.Adam([{
        'params': other_params
    }, {
        'params': model.dynamic_embedding.parameters(),
        'lr': 1e-3
    }],
                           lr=1e-3,
                           betas=(0.9, 0.99))

    logging = True
    if logging:
        writer = SummaryWriter()
        log_name = 'Direct_CNN'

    history = None

    for epoch in range(epoch_num):

        if epoch > 0:
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['lr'] * 0.8

        model.train(mode=True)

        for batch_idx, sample_dict in enumerate(dm_dataloader):
            sentence = Variable(torch.LongTensor(sample_dict['sentence']))
            label = Variable(torch.LongTensor(sample_dict['label']))
            if torch.cuda.is_available():
                sentence = sentence.cuda()
                label = label.cuda()

            optimizer.zero_grad()
            pred = model.forward(sentence)
            cross_entropy = nn.NLLLoss()
            loss = cross_entropy(F.log_softmax(pred, dim=1), label)
            if batch_idx % 10 == 0:
                accuracy = valid_util.running_accuracy(pred, label)
                print('epoch: %d batch %d : loss: %4.6f accuracy: %4.6f' %
                      (epoch, batch_idx, loss.item(), accuracy))
                if logging:
                    writer.add_scalar(log_name + '_data/loss', loss.item(),
                                      epoch * 10 + batch_idx // 10)
            loss.backward()
            optimizer.step()

        model.eval()
        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report')
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalar(log_name + '_data/accuracy',
                              result_dict['accuracy'], epoch)
        accuracy = valid_util.validate(model,
                                       dm_test_set,
                                       dm_test_dataloader,
                                       mode='output')
        if accuracy > max_acc:
            max_acc = accuracy

        # dm_valid_set = pickle.load(open(os.path.join('./tmp', season_id, 'unigram_valid_dataset.pkl'), 'rb'))
        # v_acc = valid_util.validate(model, dm_valid_set, mode='output')
        # if v_acc > max_v_acc:
        #     max_v_acc = v_acc

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    print("Max Validation Accuracy: %4.6f" % max_v_acc)
    return
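This example pairs F.log_softmax with nn.NLLLoss; applied to raw logits, that combination computes the same loss as nn.CrossEntropyLoss, which examples 2 and 3 use directly:

cross_entropy = nn.CrossEntropyLoss()
loss = cross_entropy(pred, label)  # same value as nn.NLLLoss()(F.log_softmax(pred, dim=1), label)

Similarly, the manual per-epoch decay lr *= 0.8 is what torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.8) provides when stepped once per epoch.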
Code Example #6
def train(season_id, dm_train_set, dm_test_set, features, edges):

    EMBEDDING_DIM = 200
    batch_size = 128
    epoch_num = 300
    cut_off = 49
    max_acc = 0
    max_v_acc = 0
    model_save_path = './tmp/model_save/gcn_context.model'

    graph = build_graph(features, edges)
    print(graph.num_nodes, graph.num_edges)
    features = torch.FloatTensor(features)
    graph = graph.to(device)

    # dm_valid_set = pickle.load(open(os.path.join('./tmp', season_id, 'unigram_context_valid_dataset.pkl'), 'rb'))
    # dm_valid_set = enhance_dataset(graph, dm_valid_set, cut_off)
    # pickle.dump(dm_valid_set, open(os.path.join('./tmp', season_id, 'unigram_context_valid_dataset.pkl'), 'wb'))

    # dm_train_set = enhance_dataset(graph, dm_train_set, cut_off)
    # dm_test_set = enhance_dataset(graph, dm_test_set, cut_off)
    # pickle.dump(dm_train_set, open(os.path.join('./tmp', season_id, 'unigram_context_train_dataset.pkl'), 'wb'))
    # pickle.dump(dm_test_set, open(os.path.join('./tmp', season_id, 'unigram_context_test_dataset.pkl'), 'wb'))

    dm_dataloader = data.DataLoader(dataset=dm_train_set,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    num_workers=8)

    dm_test_dataloader = data.DataLoader(dataset=dm_test_set,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         drop_last=False,
                                         num_workers=8)

    model = GCNContext(graph,
                       EMBEDDING_DIM,
                       256,
                       context_mode='distance',
                       cut_off=cut_off)
    model.init_emb(features)
    print(model)
    model.to(device)

    if torch.cuda.is_available():
        print("CUDA : On")
    else:
        print("CUDA : Off")

    optimizer = optim.Adam(model.parameters(),
                           lr=1e-3,
                           betas=(0.9, 0.99),
                           weight_decay=1e-8)
    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

    logging = True
    if logging:
        writer = SummaryWriter()
        log_name = 'gcn_context'

    for epoch in tqdm(range(epoch_num)):
        model.train(mode=True)
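        # Note: since PyTorch 1.1 the convention is to call scheduler.step()
        # after the epoch's optimizer.step() calls; this snippet keeps the
        # older step-at-epoch-start order.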
        scheduler.step()
        for batch_idx, sample_dict in enumerate(dm_dataloader):
            sentence = torch.LongTensor(sample_dict['sentence'])
            context = torch.LongTensor(sample_dict['context'])
            label = torch.LongTensor(sample_dict['label'])
            distance = torch.LongTensor(sample_dict['distance'])

            sentence = sentence.to(device)
            context = context.to(device)
            label = label.to(device)
            distance = distance.to(device)

            optimizer.zero_grad()
            pred = model.forward(sentence, context, distance=distance)
            cross_entropy = nn.CrossEntropyLoss()
            loss = cross_entropy(pred, label)
            if batch_idx % 10 == 0:
                accuracy = valid_util.running_accuracy(pred, label)
                print('epoch: %d batch %d : loss: %4.6f accuracy: %4.6f' %
                      (epoch, batch_idx, loss.item(), accuracy))
                if logging:
                    writer.add_scalar(log_name + '_data/loss', loss.item(),
                                      epoch * 10 + batch_idx // 10)
            loss.backward()
            optimizer.step()

        model.eval()
        accuracy = valid_util.validate(model,
                                       dm_test_set,
                                       dm_test_dataloader,
                                       mode='output',
                                       type='graph_context')
        if accuracy > max_acc:
            max_acc = accuracy

        if logging:
            result_dict = valid_util.validate(model,
                                              dm_test_set,
                                              dm_test_dataloader,
                                              mode='report',
                                              type='graph_context')
            writer.add_scalars(
                log_name + '_data/0-PRF', {
                    '0-Precision': result_dict['0']['precision'],
                    '0-Recall': result_dict['0']['recall'],
                    '0-F1-score': result_dict['0']['f1-score']
                }, epoch)
            writer.add_scalars(
                log_name + '_data/1-PRF', {
                    '1-Precision': result_dict['1']['precision'],
                    '1-Recall': result_dict['1']['recall'],
                    '1-F1-score': result_dict['1']['f1-score']
                }, epoch)
            writer.add_scalars(log_name + '_data/accuracy', {
                'accuracy': result_dict['accuracy'],
                'max_accuracy': max_acc
            }, epoch)

        # v_acc = valid_util.validate(model, dm_valid_set, mode='output', type='graph_context')
        # if v_acc > max_v_acc:
        #     max_v_acc = v_acc

    if logging:
        writer.close()
    print("Max Accuracy: %4.6f" % max_acc)
    # print("Max Validation Accuracy: %4.6f" % max_v_acc)
    return
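build_graph is not included in the snippets. The num_nodes / num_edges attributes printed in this example and the graph.to(device) call suggest a PyTorch Geometric Data object; a minimal sketch under that assumption (the library choice and argument shapes are guesses):

import torch
from torch_geometric.data import Data

def build_graph(features, edges):
    # edges: iterable of (src, dst) pairs; features: (num_nodes, feat_dim) array.
    edge_index = torch.LongTensor(edges).t().contiguous()  # shape (2, num_edges)
    return Data(x=torch.FloatTensor(features), edge_index=edge_index)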