Code example #1
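The snippet omits its imports; a hedged preamble inferred from usage follows (the stdlib and torch imports are certain, the rest are project-local assumptions):

# Assumed preamble (not part of the original snippet):
import argparse
import collections
import os
from datetime import datetime

import torch

# Project-local helpers and Catalyst-era classes, inferred from usage (import
# paths depend on the project layout and Catalyst version):
# from catalyst.dl import SupervisedRunner, UtilsFactory
# from lib import (set_manual_seed, get_dataloaders, get_model, get_loss,
#                  get_optimizer, maybe_cuda, count_parameters, auto_file,
#                  PixelAccuracyMetric, EpochJaccardMetric,
#                  ShowPolarBatchesCallback, visualize_inria_predictions)
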
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--fast', action='store_true')
    parser.add_argument('--fp16', action='store_true')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        required=True,
                        help='Data directory for INRIA satellite dataset')
    parser.add_argument('-m', '--model', type=str, default='unet',
                        help='Model name, e.g. unet')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=8,
                        help='Batch size during training, e.g. -b 64')
    parser.add_argument('-e',
                        '--epochs',
                        type=int,
                        default=150,
                        help='Number of epochs to run')
    parser.add_argument('-es',
                        '--early-stopping',
                        type=int,
                        default=None,
                        help='Maximum number of epochs without improvement')
    # parser.add_argument('-f', '--fold', default=None, required=True, type=int, help='Fold to train')
    # parser.add_argument('-fe', '--freeze-encoder', type=int, default=0, help='Freeze encoder parameters for N epochs')
    # parser.add_argument('-ft', '--fine-tune', action='store_true')
    parser.add_argument('-lr',
                        '--learning-rate',
                        type=float,
                        default=1e-3,
                        help='Initial learning rate')
    parser.add_argument('-l',
                        '--criterion',
                        type=str,
                        default='bce',
                        help='Criterion')
    parser.add_argument('-o',
                        '--optimizer',
                        default='Adam',
                        help='Name of the optimizer')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-w',
                        '--workers',
                        default=8,
                        type=int,
                        help='Num workers')

    args = parser.parse_args()
    set_manual_seed(args.seed)

    data_dir = args.data_dir
    num_workers = args.workers
    num_epochs = args.epochs
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    model_name = args.model
    optimizer_name = args.optimizer
    image_size = (512, 512)

    train_loader, valid_loader = get_dataloaders(data_dir=data_dir,
                                                 batch_size=batch_size,
                                                 num_workers=num_workers,
                                                 image_size=image_size,
                                                 fast=args.fast)

    model = maybe_cuda(get_model(model_name, image_size=image_size))
    criterion = get_loss(args.criterion)
    optimizer = get_optimizer(optimizer_name, model.parameters(),
                              learning_rate)

    loaders = collections.OrderedDict()
    loaders["train"] = train_loader
    loaders["valid"] = valid_loader

    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[10, 20, 40],
                                                     gamma=0.3)

    # model runner
    runner = SupervisedRunner()

    if args.checkpoint:
        checkpoint = UtilsFactory.load_checkpoint(auto_file(args.checkpoint))
        UtilsFactory.unpack_checkpoint(checkpoint, model=model)

        checkpoint_epoch = checkpoint['epoch']
        print('Loaded model weights from', args.checkpoint)
        print('Epoch   :', checkpoint_epoch)
        print('Metrics:', checkpoint['epoch_metrics'])

        # try:
        #     UtilsFactory.unpack_checkpoint(checkpoint, optimizer=optimizer)
        # except Exception as e:
        #     print('Failed to restore optimizer state', e)

        # try:
        #     UtilsFactory.unpack_checkpoint(checkpoint, scheduler=scheduler)
        # except Exception as e:
        #     print('Failed to restore scheduler state', e)

    current_time = datetime.now().strftime('%b%d_%H_%M')
    prefix = f'{current_time}_{args.model}_{args.criterion}'
    log_dir = os.path.join('runs', prefix)
    os.makedirs(log_dir, exist_ok=False)

    print('Train session:', prefix)
    print('\tFast mode  :', args.fast)
    print('\tEpochs     :', num_epochs)
    print('\tWorkers    :', num_workers)
    print('\tData dir   :', data_dir)
    print('\tLog dir    :', log_dir)
    print('\tTrain size :', len(train_loader), len(train_loader.dataset))
    print('\tValid size :', len(valid_loader), len(valid_loader.dataset))
    print('Model:', model_name)
    print('\tParameters:', count_parameters(model))
    print('\tImage size:', image_size)
    print('Optimizer:', optimizer_name)
    print('\tLearning rate:', learning_rate)
    print('\tBatch size   :', batch_size)
    print('\tCriterion    :', args.criterion)

    # model training
    runner.train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        callbacks=[
            # OneCycleLR(
            #     cycle_len=num_epochs,
            #     div_factor=10,
            #     increase_fraction=0.3,
            #     momentum_range=(0.95, 0.85)),
            PixelAccuracyMetric(),
            EpochJaccardMetric(),
            ShowPolarBatchesCallback(visualize_inria_predictions,
                                     metric='accuracy',
                                     minimize=False),
            # EarlyStoppingCallback(patience=5, min_delta=0.01, metric='jaccard', minimize=False),
        ],
        loaders=loaders,
        logdir=log_dir,
        num_epochs=num_epochs,
        verbose=True,
        main_metric='jaccard',
        minimize_metric=False,
        state_kwargs={"cmd_args": vars(args)})
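
A hypothetical invocation, assuming the usual `if __name__ == '__main__': main()` guard and that the file is saved as train.py (neither is shown in the snippet):

python train.py -dd /path/to/inria --model unet -b 16 -e 150 -lr 1e-3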
Code example #2
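As above, the imports are omitted; a hedged preamble inferred from usage (the project-local modules are shown only as comments, since their paths are not in the snippet):

# Assumed preamble (not part of the original snippet):
import ast
import json
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

# Project-local modules, inferred from usage; note `io` is the project's own
# graph-IO module, not the stdlib io:
# import embedding_utils, sp_models, io
# from model_utils import get_model, to_np
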
def main(model_params, model_name, data_folder, word_embeddings, test_set,
         property_index, save_folder, load_model, result_folder):

    with open(model_params) as f:
        model_params = json.load(f)

    embeddings, word2idx = embedding_utils.load(data_folder + word_embeddings)
    print("Loaded embeddings:", embeddings.shape)

    def check_data(data):
        for g in data:
            if 'vertexSet' not in g:
                print("vertexSet missing")

    print("Reading the property index")
    with open(data_folder + "models/" + model_name + ".property2idx") as f:
        property2idx = ast.literal_eval(f.read())

    max_sent_len = 36
    print("Max sentence length set to: {}".format(max_sent_len))

    graphs_to_indices = sp_models.to_indices_and_entity_pair
    if model_name == "ContextAware":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding_and_entity_pair
    elif model_name == "PCNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions_and_pcnn_mask_and_entity_pair
    elif model_name == "CNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions_and_entity_pair
    elif model_name == "GPGNN":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding_and_entity_pair

    _, position2idx = embedding_utils.init_random(
        np.arange(-max_sent_len, max_sent_len), 1, add_all_zeroes=True)

    training_data = None

    n_out = len(property2idx)
    print("N_out:", n_out)

    model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                  n_out).cuda()
    model.load_state_dict(torch.load(save_folder + load_model))
    print("Testing")

    print("Results on the test set")
    test_set, _ = io.load_relation_graphs_from_file(data_folder + test_set)
    test_as_indices = list(
        graphs_to_indices(test_set,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx))

    print("Start testing!")
    result_file = open(result_folder + "_" + model_name, "w")
    batch_size = model_params['batch_size']
    for i in tqdm(range(test_as_indices[0].shape[0] // batch_size)):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        sentence_input = test_as_indices[0][batch]
        entity_markers = test_as_indices[1][batch]
        labels = test_as_indices[2][batch]

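        # NOTE: Variable(..., volatile=True) below is the legacy (pre-0.4)
        # PyTorch inference API; on current PyTorch, wrap the forward pass in
        # `with torch.no_grad():` instead.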
        if model_name == "GPGNN":
            output = model(
                Variable(torch.from_numpy(sentence_input.astype(int)),
                         volatile=True).cuda(),
                Variable(torch.from_numpy(entity_markers.astype(int)),
                         volatile=True).cuda(),
                test_as_indices[3][batch])
        elif model_name == "PCNN":
            output = model(
                Variable(torch.from_numpy(sentence_input.astype(int)),
                         volatile=True).cuda(),
                Variable(torch.from_numpy(entity_markers.astype(int)),
                         volatile=True).cuda(),
                Variable(torch.from_numpy(
                    np.array(test_as_indices[3][batch])).float(),
                         requires_grad=False,
                         volatile=True).cuda())
        else:
            output = model(
                Variable(torch.from_numpy(sentence_input.astype(int)),
                         volatile=True).cuda(),
                Variable(torch.from_numpy(entity_markers.astype(int)),
                         volatile=True).cuda())

        score = F.softmax(output, dim=-1)
        score = to_np(score).reshape(-1, n_out)
        labels = labels.reshape(-1)
        p_indices = labels != 0
        score = score[p_indices].tolist()
        labels = labels[p_indices].tolist()
        entity_pairs = test_as_indices[-1][batch]
        if model_name not in ("LSTM", "PCNN", "CNN"):
            entity_pairs = reduce(lambda x, y: x + y, entity_pairs)
        for row_scores, true_label, entity_pair in zip(score, labels,
                                                       entity_pairs):
            for index, prob in enumerate(row_scores):
                result_file.write(
                    str(index) + "\t" + str(prob) + "\t" +
                    str(1 if index == true_label else 0) + "\t" +
                    entity_pair[0] + "\t" + entity_pair[1] + "\n")

    result_file.close()
Code example #3
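A hedged preamble for this snippet, inferred from usage; the project-local names are assumptions, and `args` plus `p0_index` (the index of the empty relation "P0") must exist as module-level globals:

# Assumed preamble (not part of the original snippet):
import ast
import json
import os
from functools import reduce

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from tqdm import tqdm

# Project-local modules, inferred from usage:
# import embedding_utils, sp_models, io, context_utils, evaluation_utils
# from model_utils import get_model, to_np
# Caveat: np.long (used below) was removed in NumPy >= 1.24; substitute
# np.int64 on current NumPy.
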
def test():
    """ Main Configurations """
    model_name = "RECON"
    load_model = "RECON-{}.out"  # you should choose the proper model to load
    # device_id = 0

    data_folder = "./data/WikipediaWikidataDistantSupervisionAnnotations.v1.0/enwiki-20160501/"
    save_folder = "./models/RECON/"
    result_folder = "result/"

    model_params = "model_params.json"
    word_embeddings = "./glove.6B/glove.6B.50d.txt"

    test_set = "semantic-graphs-filtered-held-out.02_06.json"

    gat_embedding_file = None
    gat_relation_embedding_file = None
    if "RECON" in model_name:
        context_data_file = "./data/WikipediaWikidataDistantSupervisionAnnotations.v1.0/entities_context.json"
    if "KGGAT" in model_name:
        gat_embedding_file = './models/GAT/WikipediaWikidataDistantSupervisionAnnotations/final_entity_embeddings.json'
        gat_entity2id_file = './data/GAT/WikipediaWikidataDistantSupervisionAnnotations.v1.0/entity2id.txt'
    if model_name == "RECON":
        gat_relation_embedding_file = './re/models/GAT_sep_space/WikipediaWikidataDistantSupervisionAnnotations/final_relation_embeddings.json'
        gat_relation2id_file = './data/GAT_sep_space/WikipediaWikidataDistantSupervisionAnnotations.v1.0/relation2id.txt'
        w_ent2rel_all_rels_file = './re/models/GAT_sep_space/WikipediaWikidataDistantSupervisionAnnotations/W_ent2rel.json.npy'

    use_char_vocab = False

    # a file to store property2idx
    # if is None use model_name.property2idx
    property_index = None

    with open(model_params) as f:
        model_params = json.load(f)
    global args
    save_folder = args.save_folder
    if args.test_file != '':
        test_set = args.test_file
    result_folder = args.result_folder
    model_params['batch_size'] = args.batch_size
    if not os.path.exists(result_folder):
        os.makedirs(result_folder)

    char_vocab_file = os.path.join(save_folder, "char_vocab.json")

    sp_models.set_max_edges(
        model_params['max_num_nodes'] * (model_params['max_num_nodes'] - 1),
        model_params['max_num_nodes'])

    if context_data_file:
        with open(context_data_file, 'r') as f:
            context_data = json.load(f)
    if gat_embedding_file:
        with open(gat_embedding_file, 'r') as f:
            gat_embeddings = json.load(f)
        with open(gat_relation_embedding_file, 'r') as f:
            gat_relation_embeddings = json.load(f)
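    # Caveat (as written): with model_name == "RECON", gat_embedding_file stays
    # None, so gat_embeddings, gat_relation_embeddings and gat_entity2idx are
    # never loaded even though the RECON branches later use them; populate
    # gat_embedding_file and gat_entity2id_file for RECON as well before
    # running this path.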
    if gat_relation_embedding_file:
        W_ent2rel_all_rels = np.load(w_ent2rel_all_rels_file)
        with open(gat_entity2id_file, 'r') as f:
            gat_entity2idx = {}
            data = f.read()
            lines = data.split('\n')
            for line in lines:
                line_arr = line.split(' ')
                if len(line_arr) == 2:
                    gat_entity2idx[line_arr[0].strip()] = line_arr[1].strip()
        with open(gat_relation2id_file, 'r') as f:
            gat_relation2idx = {}
            data = f.read()
            lines = data.split('\n')
            for line in lines:
                line_arr = line.split(' ')
                if len(line_arr) == 2:
                    gat_relation2idx[line_arr[0].strip()] = line_arr[1].strip()

    embeddings, word2idx = embedding_utils.load(word_embeddings)
    print("Loaded embeddings:", embeddings.shape)

    def check_data(data):
        for g in data:
            if 'vertexSet' not in g:
                print("vertexSet missing")

    print("Reading the property index")
    with open(os.path.join(save_folder, model_name + ".property2idx")) as f:
        property2idx = ast.literal_eval(f.read())
    idx2property = {v: k for k, v in property2idx.items()}
    print("Reading the entity index")
    with open(os.path.join(save_folder, model_name + ".entity2idx")) as f:
        entity2idx = ast.literal_eval(f.read())
    idx2entity = {v: k for k, v in entity2idx.items()}
    context_data['ALL_ZERO'] = {
        'desc': '',
        'label': 'ALL_ZERO',
        'instances': [],
        'aliases': []
    }

    with open(char_vocab_file, 'r') as f:
        char_vocab = json.load(f)

    max_sent_len = 36
    print("Max sentence length set to: {}".format(max_sent_len))

    graphs_to_indices = sp_models.to_indices_and_entity_pair
    if model_name == "ContextAware":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "PCNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions_and_pcnn_mask
    elif model_name == "CNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions
    elif model_name == "GPGNN":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "RECON-EAC":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "RECON-EAC-KGGAT":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "RECON":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding

    _, position2idx = embedding_utils.init_random(
        np.arange(-max_sent_len, max_sent_len), 1, add_all_zeroes=True)

    training_data = None

    n_out = len(property2idx)
    print("N_out:", n_out)

    if "RECON" not in model_name:
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out)
    elif model_name == "RECON-EAC":
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out, char_vocab)
    elif model_name == "RECON-EAC-KGGAT":
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out, char_vocab)
    elif model_name == "RECON":
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out, char_vocab,
                                      gat_relation_embeddings,
                                      W_ent2rel_all_rels, idx2property,
                                      gat_relation2idx)

    model = model.cuda()
    model.load_state_dict(torch.load(os.path.join(save_folder, load_model)))

    print("Testing")

    print("Results on the test set")
    test_set, _ = io.load_relation_graphs_from_file(data_folder + test_set,
                                                    data='nyt')
    test_as_indices = list(
        graphs_to_indices(test_set,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx,
                          entity2idx=entity2idx))

    print("Start testing!")
    result_file = open(os.path.join(result_folder, "_" + model_name), "w")
    test_f1 = 0.0
    batch_size = model_params['batch_size']
    # The original snippet sliced through `indices` without ever defining it;
    # an identity index over the test set restores the intended behavior.
    indices = np.arange(test_as_indices[0].shape[0])
    for i in tqdm(range(test_as_indices[0].shape[0] // batch_size)):
        batch = indices[i * batch_size:(i + 1) * batch_size]
        sentence_input = test_as_indices[0][batch]
        entity_markers = test_as_indices[1][batch]
        labels = test_as_indices[2][batch]
        if "RECON" in model_name:
            entity_indices = test_as_indices[4][batch]
            unique_entities, unique_entities_surface_forms, max_occurred_entity_in_batch_pos = \
                context_utils.get_batch_unique_entities(
                    test_as_indices[4][batch], test_as_indices[5][batch])
            unique_entities_context_indices = context_utils.get_context_indices(
                unique_entities,
                unique_entities_surface_forms,
                context_data,
                idx2entity,
                word2idx,
                char_vocab,
                model_params['conv_filter_size'],
                max_sent_len=32,
                max_num_contexts=32,
                max_char_len=10,
                data='nyt')
            entities_position = context_utils.get_entity_location_unique_entities(
                unique_entities, entity_indices)
        if model_name == "RECON-EAC-KGGAT":
            gat_entity_embeddings = context_utils.get_gat_entity_embeddings(
                entity_indices, entity2idx, idx2entity, gat_entity2idx,
                gat_embeddings)
        elif model_name == "RECON":
            gat_entity_embeddings, nonzero_gat_entity_embeddings, nonzero_entity_pos = \
                context_utils.get_selected_gat_entity_embeddings(
                    entity_indices, entity2idx, idx2entity, gat_entity2idx,
                    gat_embeddings)

        with torch.no_grad():
            if model_name == "RECON":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    test_as_indices[3][batch],
                    Variable(torch.from_numpy(unique_entities.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(entity_indices.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(bool))).cuda(),
                    Variable(torch.from_numpy(entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        nonzero_gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda(),
                    nonzero_entity_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC-KGGAT":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    test_as_indices[3][batch],
                    Variable(torch.from_numpy(unique_entities.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(entity_indices.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(bool))).cuda(),
                    Variable(torch.from_numpy(entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    test_as_indices[3][batch],
                    Variable(torch.from_numpy(unique_entities.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(entity_indices.astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(np.long))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(bool))).cuda(),
                    Variable(torch.from_numpy(entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos)
            elif model_name == "GPGNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    test_as_indices[3][batch])
            elif model_name == "PCNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        np.array(test_as_indices[3]
                                 [i * batch_size:(i + 1) * batch_size])).float(),
                             requires_grad=False).cuda())
            else:
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda())

            _, predicted = torch.max(output, dim=1)
            labels_copy = labels.reshape(-1).tolist()
            predicted = predicted.data.tolist()
            p_indices = np.array(labels_copy) != 0
            predicted = np.array(predicted)[p_indices].tolist()
            labels_copy = np.array(labels_copy)[p_indices].tolist()

            _, _, add_f1 = evaluation_utils.evaluate_instance_based(
                predicted, labels_copy, empty_label=p0_index)
            test_f1 += add_f1

        score = F.softmax(output, dim=-1)
        score = to_np(score).reshape(-1, n_out)
        labels = labels.reshape(-1)
        p_indices = labels != 0
        score = score[p_indices].tolist()
        labels = labels[p_indices].tolist()
        pred_labels = np.argmax(score, axis=-1)
        kept_positions = [idx for idx in range(len(p_indices)) if p_indices[idx]]
        entity_pairs = test_as_indices[-1][i * batch_size:(i + 1) * batch_size]
        if model_name not in ("LSTM", "PCNN", "CNN"):
            entity_pairs = reduce(lambda x, y: x + y, entity_pairs)

        start_idx = i * batch_size
        for index, (row_scores, true_label,
                    entity_pair) in enumerate(zip(score, labels,
                                                  entity_pairs)):
            sent = ' '.join(test_set[start_idx + kept_positions[index] //
                                     (model_params['max_num_nodes'] *
                                      (model_params['max_num_nodes'] - 1))]
                            ['tokens']).strip()
            result_file.write("{} | {} | {} | {} | {} | {}\n".format(
                sent, entity_pair[0], entity_pair[1],
                idx2property[pred_labels[index]], idx2property[true_label],
                row_scores[pred_labels[index]]))

    print(
        "Test f1: ", test_f1 * 1.0 /
        (test_as_indices[0].shape[0] / model_params['batch_size']))
    result_file.close()
Code example #4
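Assumed preamble, inferred from usage (the segmentation helpers are project-local and shown only as comments):

# Assumed preamble (not part of the original snippet):
import argparse
import os

import cv2
import numpy as np
from tqdm import tqdm

# Project-local helpers, inferred from usage (import paths are assumptions):
# from catalyst.dl import UtilsFactory
# from lib import get_model, auto_file, find_in_dir, read_rgb_image, predict
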
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, default='unet',
                        help='Model name, e.g. unet')
    parser.add_argument('-dd',
                        '--data-dir',
                        type=str,
                        default=None,
                        required=True,
                        help='Data dir')
    parser.add_argument(
        '-c',
        '--checkpoint',
        type=str,
        default=None,
        required=True,
        help='Checkpoint filename to use as initial model weights')
    parser.add_argument('-b',
                        '--batch-size',
                        type=int,
                        default=16,
                        help='Batch size for inference')
    parser.add_argument('-tta',
                        '--tta',
                        default=None,
                        type=str,
                        help='Type of TTA to use [fliplr, d4]')
    args = parser.parse_args()

    data_dir = args.data_dir
    checkpoint_file = auto_file(args.checkpoint)
    run_dir = os.path.dirname(os.path.dirname(checkpoint_file))
    out_dir = os.path.join(run_dir, 'evaluation')
    os.makedirs(out_dir, exist_ok=True)

    model = get_model(args.model)

    checkpoint = UtilsFactory.load_checkpoint(checkpoint_file)
    checkpoint_epoch = checkpoint['epoch']
    print('Loaded model weights from', args.checkpoint)
    print('Epoch   :', checkpoint_epoch)
    print('Metrics (Train):', 'IoU:',
          checkpoint['epoch_metrics']['train']['jaccard'], 'Acc:',
          checkpoint['epoch_metrics']['train']['accuracy'])
    print('Metrics (Valid):', 'IoU:',
          checkpoint['epoch_metrics']['valid']['jaccard'], 'Acc:',
          checkpoint['epoch_metrics']['valid']['accuracy'])

    UtilsFactory.unpack_checkpoint(checkpoint, model=model)

    model = model.cuda().eval()

    train_images = find_in_dir(os.path.join(data_dir, 'train', 'images'))
    for fname in tqdm(train_images, total=len(train_images)):
        image = read_rgb_image(fname)
        mask = predict(model,
                       image,
                       tta=args.tta,
                       image_size=(512, 512),
                       batch_size=args.batch_size,
                       activation='sigmoid')
        mask = (mask * 255).astype(np.uint8)
        name = os.path.join(out_dir, os.path.basename(fname))
        cv2.imwrite(name, mask)
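
A hypothetical invocation (the script name and checkpoint path are placeholders; the flags are those defined above):

python predict.py -dd /path/to/inria -c runs/<run>/checkpoints/best.pth -b 16 -tta d4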
Code example #5
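Assumed preamble, inferred from usage; this method also relies on its class providing self.dataset, self.args, self.output_path, self.split_dataset and a project-local get_model:

# Assumed preamble (not part of the original snippet):
import json
import logging
import os
from datetime import datetime

import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support
from sklearn.model_selection import KFold
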
    def run(self):

        self.dataset.load()

        X_train, X_test, y_train, y_test = self.split_dataset()

        logging.info("Train data: {}".format(X_train.shape))
        logging.info("Test data: {}".format(X_test.shape))

        labels = self.dataset.y_cols
        results = None
        k_folds = self.args.kfolds
        scores = []
        best_model = None
        best_score = 0
        # shuffle=True is required for random_state to have an effect;
        # recent scikit-learn raises a ValueError otherwise.
        cv = KFold(n_splits=k_folds, shuffle=True,
                   random_state=self.args.random_state)

        for k, fold in enumerate(cv.split(X_train, y_train)):

            logging.info('training fold {}'.format(k))
            train, valid = fold
            X_kfold, X_valid = X_train[train], X_train[valid]
            y_kfold, y_valid = y_train[train], y_train[valid]

            model = get_model(self)
            model.train(X_kfold, y_kfold)
            y_pred = model.predict(X_valid)

            score = precision_recall_fscore_support(y_valid, y_pred, average='macro')
            score = score[2]  # F1
            scores.append(score)
            print(f"CV {k} F1: {score}")

            if score > best_score:
                best_score = score
                best_model = model

            y_pred = pd.DataFrame(y_pred, columns=labels)
            y_valid = pd.DataFrame(y_valid, columns=labels)
            results_df = y_valid.merge(y_pred, left_index=True, right_index=True,
                                       suffixes=('', '_pred'))
            results_df['kfold'] = k
            results_df['set'] = 'cv'
            results_df['id'] = X_valid[:, 0]
            results_df['timestamp'] = datetime.now().timestamp()
            results_df['run_id'] = self.args.run_id
            # DataFrame.append was removed in pandas 2.0; use pd.concat instead.
            results = results_df if results is None else pd.concat(
                [results, results_df], ignore_index=True)

        # predict on test set
        y_pred = best_model.predict(X_test)


        score = precision_recall_fscore_support(y_test, y_pred, average='macro')
        score = score[2]  # F1
        scores = np.array(scores)
        print("CV F1: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2))
        print("Test F1: %0.2f" % score)

        y_pred = pd.DataFrame(y_pred, columns=labels)
        y_test = pd.DataFrame(y_test, columns=labels)

        results_df = y_test.merge(y_pred, left_index=True, right_index=True,
                                  suffixes=('', '_pred'))
        results_df['kfold'] = 0
        results_df['set'] = 'test'
        results_df['id'] = X_test[:, 0]
        results_df['timestamp'] = datetime.now().timestamp()
        results_df['run_id'] = self.args.run_id
        # DataFrame.append was removed in pandas 2.0; use pd.concat instead.
        results = pd.concat([results, results_df], ignore_index=True)
        write_header = not os.path.exists(self.output_path)
        # save results
        with open(self.output_path, 'a') as f:
            results.to_csv(path_or_buf=f, index=False, header=write_header)

        # save hyperparams
        self.args.final_timestamp = datetime.now().timestamp()
        filepath = os.path.splitext(self.output_path)[0] + '.json'
        with open(filepath, 'a') as f:
            config = json.dumps(self.args.__dict__)
            f.write(config + '\n')  # the original wrote '\r', presumably a typo
Code example #6
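Assumed preamble, inferred from usage (the training utilities are project-local and shown only as comments):

# Assumed preamble (not part of the original snippet):
import random
import time

import numpy as np
import torch

# Project-local helpers, inferred from usage:
# from lib import (Logger, get_data_loaders, get_model, load_checkpoint,
#                  load_metrics, save_metrics, save_checkpoint, AverageMeter,
#                  compute_loss, test)
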
def main(config, exp_dir, checkpoint=None):
    torch.manual_seed(config["random_seed"])
    np.random.seed(config["random_seed"])
    random.seed(config["random_seed"])

    logger = Logger(exp_dir)

    device = torch.device("cuda" if config["use_gpu"] else "cpu")

    train_loader, val_loader = get_data_loaders(config, device)

    model = get_model(config["model_name"], **config["model_args"]).to(device)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=config["learning_rate"],
                                 weight_decay=config["weight_decay"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")

    if "load_encoder" in config:
        encoder_model, _ = load_checkpoint(config["load_encoder"], device,
                                           get_model)
        model.encoder = encoder_model.encoder

    if checkpoint:
        logger.log("Resuming training...")
        metrics = load_metrics(exp_dir)
        best_val_loss = checkpoint["best_val_loss"]
        i_episode = checkpoint["epoch"]
        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    else:
        i_episode = 0
        best_val_loss = float("inf")
        metrics = {
            "between_eval_time": AverageMeter(),
            "data_time": AverageMeter(),
            "batch_time": AverageMeter(),
            "train_losses": AverageMeter(),
            "train_accs": AverageMeter(),
            "val_time": AverageMeter(),
            "val_batch_time": AverageMeter(),
            "val_data_time": AverageMeter(),
            "val_losses": AverageMeter(),
            "val_accs": AverageMeter()
        }

    keep_training = True
    end = time.time()
    between_eval_end = time.time()
    while keep_training:
        for batch in train_loader:
            metrics["data_time"].update(time.time() - end)
            batch["slide"] = batch["slide"].to(device)

            model.train()
            optimizer.zero_grad()

            scores = compute_loss(config, model, batch, device)
            loss, acc = scores["loss"], scores["accuracy"]

            metrics["train_losses"].update(loss.item())
            metrics["train_accs"].update(acc)

            loss.backward()
            optimizer.step()
            metrics["batch_time"].update(time.time() - end)
            end = time.time()

            del acc
            del loss
            del batch
            if i_episode % config["eval_steps"] == 0:
                val_loss, val_acc = test(config, model, device, val_loader,
                                         metrics)
                scheduler.step(val_loss)

                metrics["between_eval_time"].update(time.time() -
                                                    between_eval_end)

                # Our optimizer has only one parameter group so the first
                # element of our list is our learning rate.
                lr = optimizer.param_groups[0]['lr']
                logger.log(
                    "Episode {0}\n"
                    "Time {metrics[between_eval_time].val:.3f} (data {metrics[data_time].val:.3f} batch {metrics[batch_time].val:.3f}) "
                    "Train loss {metrics[train_losses].val:.4e} ({metrics[train_losses].avg:.4e}) "
                    "Train acc {metrics[train_accs].val:.4f} ({metrics[train_accs].avg:.4f}) "
                    "Learning rate {lr:.2e}\n"
                    "Val time {metrics[val_time].val:.3f} (data {metrics[val_data_time].avg:.3f} batch {metrics[val_batch_time].avg:.3f}) "
                    "Val loss {metrics[val_losses].val:.4e} ({metrics[val_losses].avg:.4e}) "
                    "Val acc {metrics[val_accs].val:.4f} ({metrics[val_accs].avg:.4f}) "
                    .format(i_episode, lr=lr, metrics=metrics))

                save_metrics(metrics, exp_dir)

                is_best = val_loss < best_val_loss
                best_val_loss = val_loss if is_best else best_val_loss
                save_checkpoint(
                    {
                        "epoch": i_episode,
                        "model_name": config["model_name"],
                        "model_args": config["model_args"],
                        "state_dict": model.state_dict(),
                        "best_val_loss": best_val_loss,
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict()
                    },
                    is_best,
                    path=exp_dir)
                end = time.time()
                between_eval_end = time.time()
                del val_loss
                del val_acc

            if i_episode >= config["num_episodes"]:
                keep_training = False
                break

            i_episode += 1
Code example #7
File: train.py  Project: xiaoanshi/GP-GNN
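Assumed preamble, inferred from usage (`args` and `p0_index` must exist as module-level globals; `io` is the project's graph-IO module, not the stdlib io):

# Assumed preamble (not part of the original snippet):
import ast
import json
import os

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from tqdm import tqdm

# Project-local modules, inferred from usage:
# import embedding_utils, sp_models, io, evaluation_utils
# from model_utils import get_model
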
def main(model_params, model_name, data_folder, word_embeddings, train_set,
         val_set, property_index, learning_rate, shuffle_data, save_folder,
         save_model, grad_clip):
    if not os.path.exists(save_folder):
        os.mkdir(save_folder)

    with open(model_params) as f:
        model_params = json.load(f)

    embeddings, word2idx = embedding_utils.load(data_folder + word_embeddings)
    print("Loaded embeddings:", embeddings.shape)

    def check_data(data):
        for g in data:
            if 'vertexSet' not in g:
                print("vertexSet missing")

    training_data, _ = io.load_relation_graphs_from_file(data_folder +
                                                         train_set,
                                                         load_vertices=True)

    val_data, _ = io.load_relation_graphs_from_file(data_folder + val_set,
                                                    load_vertices=True)

    check_data(training_data)
    check_data(val_data)

    if property_index:
        print("Reading the property index from parameter")
        # Use the function parameter; the original read args.property_index.
        with open(data_folder + property_index) as f:
            property2idx = ast.literal_eval(f.read())
    else:
        _, property2idx = embedding_utils.init_random(
            {e["kbID"]
             for g in training_data for e in g["edgeSet"]} | {"P0"},
            1,
            add_all_zeroes=True,
            add_unknown=True)

    max_sent_len = max(len(g["tokens"]) for g in training_data)
    print("Max sentence length:", max_sent_len)

    max_sent_len = 36
    print("Max sentence length set to: {}".format(max_sent_len))

    graphs_to_indices = sp_models.to_indices
    if model_name == "ContextAware":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "PCNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions_and_pcnn_mask
    elif model_name == "CNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions
    elif model_name == "GPGNN":
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding

    _, position2idx = embedding_utils.init_random(
        np.arange(-max_sent_len, max_sent_len), 1, add_all_zeroes=True)

    train_as_indices = list(
        graphs_to_indices(training_data,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx))

    training_data = None

    n_out = len(property2idx)
    print("N_out:", n_out)

    val_as_indices = list(
        graphs_to_indices(val_data,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx))
    val_data = None

    print("Save property dictionary.")
    with open(data_folder + "models/" + model_name + ".property2idx",
              'w') as outfile:
        outfile.write(str(property2idx))

    print("Training the model")

    print("Initialize the model")
    model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                  n_out).cuda()

    loss_func = nn.CrossEntropyLoss(ignore_index=0).cuda()
    opt = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=learning_rate,
                           weight_decay=model_params['weight_decay'])

    indices = np.arange(train_as_indices[0].shape[0])

    step = 0
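    # NOTE: this snippet targets pre-0.4 PyTorch: Variable(..., volatile=True)
    # is the legacy inference mode (today: `with torch.no_grad():`), and
    # torch.nn.utils.clip_grad_norm was later renamed clip_grad_norm_.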
    for train_epoch in range(model_params['nb_epoch']):
        if shuffle_data:
            np.random.shuffle(indices)
        f1 = 0
        batch_size = model_params['batch_size']
        for i in tqdm(range(train_as_indices[0].shape[0] // batch_size)):
            opt.zero_grad()

            batch = indices[i * batch_size:(i + 1) * batch_size]
            sentence_input = train_as_indices[0][batch]
            entity_markers = train_as_indices[1][batch]
            labels = train_as_indices[2][batch]

            if model_name == "GPGNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    train_as_indices[3][batch])
            elif model_name == "PCNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        np.array(train_as_indices[3]
                                 [i * batch_size:(i + 1) * batch_size])).float(),
                             requires_grad=False).cuda())
            else:
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int))).cuda())

            loss = loss_func(
                output,
                Variable(torch.from_numpy(labels.astype(int))).view(-1).cuda())

            loss.backward()
            torch.nn.utils.clip_grad_norm(model.parameters(), grad_clip)
            opt.step()

            _, predicted = torch.max(output, dim=1)
            labels = labels.reshape(-1).tolist()
            predicted = predicted.data.tolist()
            p_indices = np.array(labels) != 0
            predicted = np.array(predicted)[p_indices].tolist()
            labels = np.array(labels)[p_indices].tolist()

            _, _, add_f1 = evaluation_utils.evaluate_instance_based(
                predicted, labels, empty_label=p0_index)
            f1 += add_f1

        print("Train f1: ",
              f1 / (train_as_indices[0].shape[0] / model_params['batch_size']))

        val_f1 = 0
        for i in tqdm(range(val_as_indices[0].shape[0] // batch_size)):
            sentence_input = val_as_indices[0][i * batch_size:(i + 1) * batch_size]
            entity_markers = val_as_indices[1][i * batch_size:(i + 1) * batch_size]
            labels = val_as_indices[2][i * batch_size:(i + 1) * batch_size]
            if model_name == "GPGNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int)),
                             volatile=True).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int)),
                             volatile=True).cuda(),
                    val_as_indices[3][i * batch_size:(i + 1) * batch_size])
            elif model_name == "PCNN":
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int)),
                             volatile=True).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int)),
                             volatile=True).cuda(),
                    Variable(torch.from_numpy(
                        np.array(val_as_indices[3]
                                 [i * batch_size:(i + 1) * batch_size])).float(),
                             volatile=True).cuda())
            else:
                output = model(
                    Variable(torch.from_numpy(sentence_input.astype(int)),
                             volatile=True).cuda(),
                    Variable(torch.from_numpy(entity_markers.astype(int)),
                             volatile=True).cuda())

            _, predicted = torch.max(output, dim=1)
            labels = labels.reshape(-1).tolist()
            predicted = predicted.data.tolist()
            p_indices = np.array(labels) != 0
            predicted = np.array(predicted)[p_indices].tolist()
            labels = np.array(labels)[p_indices].tolist()

            _, _, add_f1 = evaluation_utils.evaluate_instance_based(
                predicted, labels, empty_label=p0_index)
            val_f1 += add_f1
        print(
            "Validation f1: ",
            val_f1 / (val_as_indices[0].shape[0] / model_params['batch_size']))

        # save model
        if save_model and train_epoch % 5 == 0:
            torch.save(
                model.state_dict(),
                "{0}{1}-{2}.out".format(save_folder, model_name,
                                        str(train_epoch)))

        step = step + 1
Code example #8
File: train.py  Project: ansonb/RECON
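Assumed preamble, inferred from usage (`args` must exist as a module-level global; see `global args` below):

# Assumed preamble (not part of the original snippet):
import ast
import json
import os

import numpy as np
import torch
import torch.nn as nn

# Project-local modules, inferred from usage:
# import embedding_utils, sp_models, io, context_utils
# from model_utils import get_model
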
def train():
    """ Main Configurations """
    model_name = "RECON"
    data_folder = "./data/WikipediaWikidataDistantSupervisionAnnotations.v1.0/enwiki-20160501/"
    save_folder = "./models/RECON/"

    model_params = "model_params.json"
    word_embeddings = "glove.6B.50d.txt"
    train_set = "semantic-graphs-filtered-training.02_06.json"
    val_set = "semantic-graphs-filtered-validation.02_06.json"

    use_char_vocab = False

    gat_embedding_file = None
    gat_relation_embedding_file = None
    # Enter the appropriate file paths here
    if "RECON" in model_name:
        context_data_file = "./data/WikipediaWikidataDistantSupervisionAnnotations.v1.0/entities_context.json"
    if "KGGAT" in model_name:
        gat_embedding_file = './models/GAT/WikipediaWikidataDistantSupervisionAnnotations/final_entity_embeddings.json'
        gat_entity2id_file = './data/GAT/WikipediaWikidataDistantSupervisionAnnotations.v1.0/entity2id.txt'
    if model_name == "RECON":
        # Point to the trained model/embedding/data files
        gat_relation_embedding_file = './models/GAT/WikipediaWikidataDistantSupervisionAnnotations/final_relation_embeddings.json'
        gat_relation2id_file = './data/GAT/WikipediaWikidataDistantSupervisionAnnotations.v1.0/relation2id.txt'
        w_ent2rel_all_rels_file = './models/GAT/WikipediaWikidataDistantSupervisionAnnotations/W_ent2rel.json.npy'

    # a file to store property2idx
    # if is None use model_name.property2idx
    property_index = None
    learning_rate = 1e-3
    shuffle_data = True
    save_model = True
    grad_clip = 0.25
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    with open(model_params) as f:
        model_params = json.load(f)
    global args
    save_folder = args.save_folder
    model_params['batch_size'] = args.batch_size
    model_params['nb_epoch'] = args.epochs
    val_results_file = os.path.join(save_folder, 'val_results.json')

    char_vocab_file = os.path.join(save_folder, "char_vocab.json")

    if not os.path.exists(save_folder):
        os.mkdir(save_folder)

    sp_models.set_max_edges(
        model_params['max_num_nodes'] * (model_params['max_num_nodes'] - 1),
        model_params['max_num_nodes'])

    if context_data_file:
        with open(context_data_file, 'r') as f:
            context_data = json.load(f)
    if gat_embedding_file:
        with open(gat_embedding_file, 'r') as f:
            gat_embeddings = json.load(f)
        with open(gat_relation_embedding_file, 'r') as f:
            gat_relation_embeddings = json.load(f)
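    # Caveat (as written): with model_name == "RECON", gat_embedding_file stays
    # None (so gat_relation_embeddings is never loaded) and gat_entity2id_file
    # is only assigned under the "KGGAT" branch, so the open() below would
    # raise NameError; set both for RECON as well before running this path.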
    if gat_relation_embedding_file:
        W_ent2rel_all_rels = np.load(w_ent2rel_all_rels_file)
        with open(gat_entity2id_file, 'r') as f:
            gat_entity2idx = {}
            data = f.read()
            lines = data.split('\n')
            for line in lines:
                line_arr = line.split(' ')
                if len(line_arr) == 2:
                    gat_entity2idx[line_arr[0].strip()] = line_arr[1].strip()
        with open(gat_relation2id_file, 'r') as f:
            gat_relation2idx = {}
            data = f.read()
            lines = data.split('\n')
            for line in lines:
                line_arr = line.split(' ')
                if len(line_arr) == 2:
                    gat_relation2idx[line_arr[0].strip()] = line_arr[1].strip()

    embeddings, word2idx = embedding_utils.load(data_folder + word_embeddings)
    print("Loaded embeddings:", embeddings.shape)

    def check_data(data):
        for g in data:
            if 'vertexSet' not in g:
                print("vertexSet missing")

    training_data, _ = io.load_relation_graphs_from_file(data_folder +
                                                         train_set,
                                                         load_vertices=True,
                                                         data='nyt')
    if not use_char_vocab:
        char_vocab = context_utils.make_char_vocab(training_data)
        print("Save char vocab dictionary.")
        with open(char_vocab_file, 'w') as outfile:
            json.dump(char_vocab, outfile, indent=4)
    else:
        with open(char_vocab_file, 'r') as f:
            char_vocab = json.load(f)

    val_data, _ = io.load_relation_graphs_from_file(data_folder + val_set,
                                                    load_vertices=True,
                                                    data="nyt")

    check_data(training_data)
    check_data(val_data)

    if property_index:
        print("Reading the property index from parameter")
        with open(data_folder + args.property_index) as f:
            property2idx = ast.literal_eval(f.read())
        with open(data_folder + args.entity_index) as f:
            entity2idx = ast.literal_eval(f.read())
    else:
        _, property2idx = embedding_utils.init_random(
            {e["kbID"]
             for g in training_data for e in g["edgeSet"]} | {"P0"},
            1,
            add_all_zeroes=True,
            add_unknown=True)
        _, entity2idx = context_utils.init_random(
            {kbID
             for kbID, _ in context_data.items()},
            model_params['embedding_dim'],
            add_all_zeroes=True,
            add_unknown=True)
    idx2entity = {v: k for k, v in entity2idx.items()}
    context_data['ALL_ZERO'] = {
        'desc': '',
        'label': 'ALL_ZERO',
        'instances': [],
        'aliases': []
    }

    max_sent_len = max(len(g["tokens"]) for g in training_data)
    print("Max sentence length:", max_sent_len)

    # Override the data-derived value with a fixed length; the converters
    # below pad or cut every sentence to this size.
    max_sent_len = 36
    print("Max sentence length set to: {}".format(max_sent_len))

    # ContextAware, GPGNN and all RECON variants share the entity-aware
    # converter; PCNN/CNN use the relative-position converters.
    graphs_to_indices = sp_models.to_indices
    if model_name in ("ContextAware", "GPGNN", "RECON-EAC", "RECON-EAC-KGGAT",
                      "RECON"):
        graphs_to_indices = sp_models.to_indices_with_real_entities_and_entity_nums_with_vertex_padding
    elif model_name == "PCNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions_and_pcnn_mask
    elif model_name == "CNN":
        graphs_to_indices = sp_models.to_indices_with_relative_positions

    # Vocabulary of relative positions in [-max_sent_len, max_sent_len),
    # plus an all-zeroes padding entry.
    _, position2idx = embedding_utils.init_random(
        np.arange(-max_sent_len, max_sent_len),
        1,
        add_all_zeroes=True)

    train_as_indices = list(
        graphs_to_indices(training_data,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx,
                          entity2idx=entity2idx))
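    # Converter output layout, inferred from the usage below: 0=token ids,
    # 1=entity markers, 2=labels, 3=model-specific features (graph structure
    # or PCNN position masks), 4=entity ids, 5=entity surface forms
    # (4 and 5 exist only for the entity-aware converters).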

    # Free the raw graphs; only the index arrays are needed from here on.
    training_data = None

    n_out = len(property2idx)
    print("N_out:", n_out)

    val_as_indices = list(
        graphs_to_indices(val_data,
                          word2idx,
                          property2idx,
                          max_sent_len,
                          embeddings=embeddings,
                          position2idx=position2idx,
                          entity2idx=entity2idx))
    val_data = None

    print("Save property dictionary.")
    with open(os.path.join(save_folder, model_name + ".property2idx"),
              'w') as outfile:
        outfile.write(str(property2idx))
    print("Save entity dictionary.")
    with open(os.path.join(save_folder, model_name + ".entity2idx"),
              'w') as outfile:
        outfile.write(str(entity2idx))
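    # The dictionaries are written with plain str(); the property_index branch
    # above reads them back with ast.literal_eval.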

    print("Initializing the model")

    if "RECON" not in model_name:
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out)
    elif model_name in ("RECON-EAC", "RECON-EAC-KGGAT"):
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out, char_vocab)
    elif model_name == "RECON":
        model = get_model(model_name)(model_params, embeddings, max_sent_len,
                                      n_out, char_vocab,
                                      gat_relation_embeddings,
                                      W_ent2rel_all_rels, idx2property,
                                      gat_relation2idx)

    model = model.cuda()
    loss_func = nn.CrossEntropyLoss(ignore_index=0).cuda()
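    # Index 0 is the all-zeroes padding relation, so it is excluded from the
    # loss via ignore_index.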

    opt = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                  model.parameters()),
                           lr=learning_rate,
                           weight_decay=model_params['weight_decay'])
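    # Only parameters with requires_grad=True are optimized; this presumably
    # leaves any frozen (e.g. pretrained-embedding) weights untouched.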

    indices = np.arange(train_as_indices[0].shape[0])

    print("Training the model")
    step = 0
    val_results = []
    for train_epoch in range(model_params['nb_epoch']):
        if shuffle_data:
            np.random.shuffle(indices)
        f1 = 0
        for i in tqdm(
                range(train_as_indices[0].shape[0] //
                      model_params['batch_size'])):
            opt.zero_grad()

            # Row indices of the current shuffled mini-batch.
            batch = indices[i * model_params['batch_size']:(i + 1) *
                            model_params['batch_size']]
            sentence_input = train_as_indices[0][batch]
            entity_markers = train_as_indices[1][batch]
            labels = train_as_indices[2][batch]
            if "RECON" in model_name:
                entity_indices = train_as_indices[4][batch]
                unique_entities, unique_entities_surface_forms, max_occurred_entity_in_batch_pos = context_utils.get_batch_unique_entities(
                    train_as_indices[4][batch], train_as_indices[5][batch])
                unique_entities_context_indices = context_utils.get_context_indices(
                    unique_entities,
                    unique_entities_surface_forms,
                    context_data,
                    idx2entity,
                    word2idx,
                    char_vocab,
                    model_params['conv_filter_size'],
                    max_sent_len=32,
                    max_num_contexts=32,
                    max_char_len=10,
                    data='nyt')
                entities_position = context_utils.get_entity_location_unique_entities(
                    unique_entities, entity_indices)
            if model_name == "RECON-EAC-KGGAT":
                gat_entity_embeddings = context_utils.get_gat_entity_embeddings(
                    entity_indices, entity2idx, idx2entity, gat_entity2idx,
                    gat_embeddings)
            elif model_name == "RECON":
                gat_entity_embeddings, nonzero_gat_entity_embeddings, nonzero_entity_pos = context_utils.get_selected_gat_entity_embeddings(
                    entity_indices, entity2idx, idx2entity, gat_entity2idx,
                    gat_embeddings)

            if model_name == "RECON":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    train_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        nonzero_gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda(),
                    nonzero_entity_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC-KGGAT":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    train_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    train_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos)
            elif model_name == "GPGNN":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    train_as_indices[3][batch])
            elif model_name == "PCNN":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        np.array(train_as_indices[3])[batch]).float(),
                             requires_grad=False).cuda())
            else:
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda())

            loss = loss_func(
                output,
                Variable(torch.from_numpy(labels.astype(int))).view(-1).cuda())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            opt.step()

            _, predicted = torch.max(output, dim=1)
            labels = labels.reshape(-1).tolist()
            predicted = predicted.data.tolist()
            # Score only the non-padding positions (label 0 is reserved).
            p_indices = np.array(labels) != 0
            predicted = np.array(predicted)[p_indices].tolist()
            labels = np.array(labels)[p_indices].tolist()

            _, _, add_f1 = evaluation_utils.evaluate_instance_based(
                predicted, labels, empty_label=p0_index)
            f1 += add_f1

        train_f1 = f1 / (train_as_indices[0].shape[0] //
                         model_params['batch_size'])
        print("Train f1: ", train_f1)

        val_f1 = 0
        for i in tqdm(
                range(val_as_indices[0].shape[0] //
                      model_params['batch_size'])):
            # Contiguous (unshuffled) mini-batch for validation.
            batch = slice(i * model_params['batch_size'],
                          (i + 1) * model_params['batch_size'])
            sentence_input = val_as_indices[0][batch]
            entity_markers = val_as_indices[1][batch]
            labels = val_as_indices[2][batch]
            if "RECON" in model_name:
                entity_indices = val_as_indices[4][batch]
                unique_entities, unique_entities_surface_forms, max_occurred_entity_in_batch_pos = context_utils.get_batch_unique_entities(
                    val_as_indices[4][batch], val_as_indices[5][batch])
                unique_entities_context_indices = context_utils.get_context_indices(
                    unique_entities,
                    unique_entities_surface_forms,
                    context_data,
                    idx2entity,
                    word2idx,
                    char_vocab,
                    model_params['conv_filter_size'],
                    max_sent_len=32,
                    max_num_contexts=32,
                    max_char_len=10,
                    data='nyt')
                entities_position = context_utils.get_entity_location_unique_entities(
                    unique_entities, entity_indices)
            if model_name == "RECON-EAC-KGGAT":
                gat_entity_embeddings = context_utils.get_gat_entity_embeddings(
                    entity_indices, entity2idx, idx2entity, gat_entity2idx,
                    gat_embeddings)
            elif model_name == "RECON":
                gat_entity_embeddings, nonzero_gat_entity_embeddings, nonzero_entity_pos = context_utils.get_selected_gat_entity_embeddings(
                    entity_indices, entity2idx, idx2entity, gat_entity2idx,
                    gat_embeddings)

            if model_name == "RECON":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    val_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        nonzero_gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda(),
                    nonzero_entity_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC-KGGAT":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    val_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos,
                    Variable(torch.from_numpy(
                        gat_entity_embeddings.astype(np.float32)),
                             requires_grad=False).cuda())
            elif model_name == "RECON-EAC":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    val_as_indices[3][batch],
                    Variable(torch.from_numpy(
                        unique_entities.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        entity_indices.astype(np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[0].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[1].astype(
                            np.int64))).cuda(),
                    Variable(torch.from_numpy(
                        unique_entities_context_indices[2].astype(
                            bool))).cuda(),
                    Variable(torch.from_numpy(
                        entities_position.astype(int))).cuda(),
                    max_occurred_entity_in_batch_pos)
            elif model_name == "GPGNN":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    val_as_indices[3][batch])
            elif model_name == "PCNN":
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        np.array(val_as_indices[3][batch])).float(),
                             requires_grad=False).cuda())
            else:
                output = model(
                    Variable(torch.from_numpy(
                        sentence_input.astype(int))).cuda(),
                    Variable(torch.from_numpy(
                        entity_markers.astype(int))).cuda())

            _, predicted = torch.max(output, dim=1)
            labels = labels.reshape(-1).tolist()
            predicted = predicted.data.tolist()
            # Score only the non-padding positions (label 0 is reserved).
            p_indices = np.array(labels) != 0
            predicted = np.array(predicted)[p_indices].tolist()
            labels = np.array(labels)[p_indices].tolist()

            _, _, add_f1 = evaluation_utils.evaluate_instance_based(
                predicted, labels, empty_label=p0_index)
            val_f1 += add_f1

        val_f1 = val_f1 / (val_as_indices[0].shape[0] //
                           model_params['batch_size'])
        print("Validation f1: ", val_f1)

        val_results.append({'train_f1': train_f1, 'val_f1': val_f1})

        # Save a checkpoint after every epoch.
        if save_model:
            torch.save(
                model.state_dict(),
                os.path.join(save_folder,
                             "{0}-{1}.out".format(model_name, train_epoch)))

        step += 1

        # Persist the running metrics after every epoch.
        with open(val_results_file, 'w') as f:
            json.dump(val_results,
                      f,
                      indent=4,
                      cls=context_utils.CustomEncoder)