Example #1
def train(model_save_path, root_dir, day, train_log_dir, args, layer, dim):
    # print params
    tab_printer(args)
    # load train dataset
    dataset = DataSet(os.path.join(root_dir, day))  # TODO: update the training data

    # create model on GPU:0
    running_context = torch.device("cuda:0")
    model = GCN(args, running_context, layer, dim).to(running_context)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    model.train()

    # train logs
    train_log_list = np.zeros((args.epochs, 24), dtype=np.float64)  # np.float was removed in NumPy 1.24+

    print("\n" + "+" * 5 + " Train on day {} layer {} dim {} ".format(day, layer, dim) + "+" * 5)
    best_loss = float("inf")
    last_loss_decrease_epoch = 0
    stop_flag = False
    best_model = None
    train_start_time = time.perf_counter()
    for epoch in range(args.epochs):
        start_time = time.perf_counter()
        total_loss = 0.
        print("\n" + "+" * 10 + " epoch {:3d} ".format(epoch) + "+" * 10)
        for i in range(len(dataset)):  # 24 hours
            data = dataset[i]
            edge_index = data.edge_index.to(running_context)
            mask = data.mask.to(running_context)
            logits = model(data.inputs.to(running_context), edge_index)  # keep inputs on the same device as the model
            label = data.label.to(running_context, non_blocking=True)
            pos_cnt = torch.sum(label == 1)
            neg_cnt = torch.sum(label == 0)
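            # weight each class by the frequency of the opposite class so the minority class gets the larger weight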
            weight = torch.tensor([pos_cnt.float(), neg_cnt.float()]) / (neg_cnt + pos_cnt)
            loss_fn = nn.CrossEntropyLoss(weight=weight).to(running_context)
            loss = loss_fn(logits[mask], label)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            #             acc, tn, fp, fn, tp, white_p, white_r, white_incorrect, black_p, black_r, black_incorrect, macro_f1, micro_f1 \
            #                 = evaluate(logits[mask].max(1)[1], label)

            train_log_list[epoch][i] = float(loss.item())
            print("hour: {:2d}, loss: {:.4f}".format(i, loss.item()))

        if total_loss < best_loss:
            best_loss = total_loss
            last_loss_decrease_epoch = epoch
            # state_dict() holds references to live parameters; clone so later updates do not overwrite the best weights
            best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
        else:
            not_loss_decrease_epochs = epoch - last_loss_decrease_epoch
            if not_loss_decrease_epochs >= args.early_stop:
                stop_flag = True
        print("\nepoch: {:3d}, total_loss: {:.4f}, best_loss: {:.4f}, time: {:.4f}"
              .format(epoch + 1, total_loss, best_loss, time.perf_counter() - start_time))
        if stop_flag:
            print("early stop...")
            save_model(best_model, model_save_path)
            break
    print("\ntotal train time: {}".format(time.perf_counter() - train_start_time))
    # save model when not early stop
    if not stop_flag:
        save_model(best_model, model_save_path)

    # save train logs to csv file
    print("Start to save train log of {}.".format(day))
    train_log_cols = ["hour_{}".format(hour) for hour in range(24)]
    train_log_df = pd.DataFrame(train_log_list, columns=train_log_cols)
    train_log_df.to_csv(train_log_dir, float_format="%.4f", index=False, columns=train_log_cols)
    print("Save train log of {} layer {} dim {} successfully.".format(day, layer, dim))
    torch.cuda.empty_cache()
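
A minimal sketch of how this train function might be driven, assuming an argparse namespace that carries at least lr, epochs and early_stop (any further fields required by GCN or tab_printer are not shown); all paths and the layer/dim grid below are hypothetical:

# hypothetical driver for the train() above; field and path names are assumptions
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--early_stop", type=int, default=10)
    cli_args = parser.parse_args()

    for layer in (2, 3):
        for dim in (64, 128):
            train(model_save_path="models/gcn_layer{}_dim{}.pt".format(layer, dim),
                  root_dir="data",
                  day="2020-01-01",
                  train_log_dir="logs/train_layer{}_dim{}.csv".format(layer, dim),
                  args=cli_args,
                  layer=layer,
                  dim=dim)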
Example #2
def train(cmd_args):
    if not os.path.exists(cmd_args.exp_path):
        os.makedirs(cmd_args.exp_path)

    with open(joinpath(cmd_args.exp_path, 'options.txt'), 'w') as f:
        param_dict = vars(cmd_args)
        for param in param_dict:
            f.write(param + ' = ' + str(param_dict[param]) + '\n')

    logpath = joinpath(cmd_args.exp_path, 'eval.result')
    param_cnt_path = joinpath(cmd_args.exp_path, 'param_count.txt')

    # dataset and KG
    dataset = Dataset(cmd_args.data_root,
                      cmd_args.batchsize,
                      cmd_args.shuffle_sampling,
                      load_method=cmd_args.load_method)
    kg = KnowledgeGraph(dataset.fact_dict, PRED_DICT, dataset)

    # model
    if cmd_args.use_gcn == 1:
        gcn = GCN(kg,
                  cmd_args.embedding_size - cmd_args.gcn_free_size,
                  cmd_args.gcn_free_size,
                  num_hops=cmd_args.num_hops,
                  num_layers=cmd_args.num_mlp_layers,
                  transductive=cmd_args.trans == 1).to(cmd_args.device)
    else:
        gcn = TrainableEmbedding(kg,
                                 cmd_args.embedding_size).to(cmd_args.device)
    posterior_model = FactorizedPosterior(
        kg, cmd_args.embedding_size, cmd_args.slice_dim).to(cmd_args.device)
    mln = ConditionalMLN(cmd_args, dataset.rule_ls)
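    # the conditional MLN above is used only in the load_method == 0 branch (Kinship / UW-CSE / Cora)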

    if cmd_args.model_load_path is not None:
        gcn.load_state_dict(
            torch.load(joinpath(cmd_args.model_load_path, 'gcn.model')))
        posterior_model.load_state_dict(
            torch.load(joinpath(cmd_args.model_load_path, 'posterior.model')))

    # optimizers
    monitor = EarlyStopMonitor(cmd_args.patience)
    all_params = chain.from_iterable(
        [posterior_model.parameters(),
         gcn.parameters()])
    optimizer = optim.Adam(all_params,
                           lr=cmd_args.learning_rate,
                           weight_decay=cmd_args.l2_coef)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'max',
        factor=cmd_args.lr_decay_factor,
        patience=cmd_args.lr_decay_patience,
        min_lr=cmd_args.lr_decay_min)

    with open(param_cnt_path, 'w') as f:
        cnt_gcn_params = count_parameters(gcn)
        cnt_posterior_params = count_parameters(posterior_model)
        if cmd_args.use_gcn == 1:
            f.write('GCN params count: %d\n' % cnt_gcn_params)
        elif cmd_args.use_gcn == 0:
            f.write('plain params count: %d\n' % cnt_gcn_params)
        f.write('posterior params count: %d\n' % cnt_posterior_params)
        f.write('Total params count: %d\n' %
                (cnt_gcn_params + cnt_posterior_params))

    if cmd_args.no_train == 1:
        cmd_args.num_epochs = 0

    # for Freebase data
    if cmd_args.load_method == 1:

        # prepare data for M-step
        tqdm.write('preparing data for M-step...')
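        # index observed facts by predicate and by each argument position for fast rule grounding below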
        pred_arg1_set_arg2 = dict()
        pred_arg2_set_arg1 = dict()
        pred_fact_set = dict()
        for pred in dataset.fact_dict_2:
            pred_arg1_set_arg2[pred] = dict()
            pred_arg2_set_arg1[pred] = dict()
            pred_fact_set[pred] = set()
            for _, args in dataset.fact_dict_2[pred]:
                if args[0] not in pred_arg1_set_arg2[pred]:
                    pred_arg1_set_arg2[pred][args[0]] = set()
                if args[1] not in pred_arg2_set_arg1[pred]:
                    pred_arg2_set_arg1[pred][args[1]] = set()
                pred_arg1_set_arg2[pred][args[0]].add(args[1])
                pred_arg2_set_arg1[pred][args[1]].add(args[0])
                pred_fact_set[pred].add(args)

        grounded_rules = []
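        # ground each rule: enumerate observed facts for its first body atom, then complete the
        # remaining body atoms through the argument indices built above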
        for rule_idx, rule in enumerate(dataset.rule_ls):
            grounded_rules.append(set())
            body_atoms = []
            head_atom = None
            for atom in rule.atom_ls:
                if atom.neg:
                    body_atoms.append(atom)
                elif head_atom is None:
                    head_atom = atom
            # atom in body must be observed
            assert len(body_atoms) <= 2
            if len(body_atoms) > 0:
                body1 = body_atoms[0]
                for _, body1_args in dataset.fact_dict_2[body1.pred_name]:
                    var2arg = dict()
                    var2arg[body1.var_name_ls[0]] = body1_args[0]
                    var2arg[body1.var_name_ls[1]] = body1_args[1]
                    for body2 in body_atoms[1:]:
                        if body2.var_name_ls[0] in var2arg:
                            if var2arg[body2.var_name_ls[
                                    0]] in pred_arg1_set_arg2[body2.pred_name]:
                                for body2_arg2 in pred_arg1_set_arg2[
                                        body2.pred_name][var2arg[
                                            body2.var_name_ls[0]]]:
                                    var2arg[body2.var_name_ls[1]] = body2_arg2
                                    grounded_rules[rule_idx].add(
                                        tuple(sorted(var2arg.items())))
                        elif body2.var_name_ls[1] in var2arg:
                            if var2arg[body2.var_name_ls[
                                    1]] in pred_arg2_set_arg1[body2.pred_name]:
                                for body2_arg1 in pred_arg2_set_arg1[
                                        body2.pred_name][var2arg[
                                            body2.var_name_ls[1]]]:
                                    var2arg[body2.var_name_ls[0]] = body2_arg1
                                    grounded_rules[rule_idx].add(
                                        tuple(sorted(var2arg.items())))

        # Collect head atoms derived by grounded formulas
        grounded_obs = dict()
        grounded_hid = dict()
        grounded_hid_score = dict()
        cnt_hid = 0
        for rule_idx in range(len(dataset.rule_ls)):
            rule = dataset.rule_ls[rule_idx]
            for var2arg in grounded_rules[rule_idx]:
                var2arg = dict(var2arg)
                head_atom = rule.atom_ls[-1]
                assert not head_atom.neg  # head atom
                pred = head_atom.pred_name
                args = (var2arg[head_atom.var_name_ls[0]],
                        var2arg[head_atom.var_name_ls[1]])
                if args in pred_fact_set[pred]:
                    if (pred, args) not in grounded_obs:
                        grounded_obs[(pred, args)] = []
                    grounded_obs[(pred, args)].append(rule_idx)
                else:
                    if (pred, args) not in grounded_hid:
                        grounded_hid[(pred, args)] = []
                    grounded_hid[(pred, args)].append(rule_idx)
        tqdm.write('observed: %d, hidden: %d' %
                   (len(grounded_obs), len(grounded_hid)))

        # Aggregate atoms by predicates for fast inference
        pred_aggregated_hid = dict()
        pred_aggregated_hid_args = dict()
        for (pred, args) in grounded_hid:
            if pred not in pred_aggregated_hid:
                pred_aggregated_hid[pred] = []
            if pred not in pred_aggregated_hid_args:
                pred_aggregated_hid_args[pred] = []
            pred_aggregated_hid[pred].append(
                (dataset.const2ind[args[0]], dataset.const2ind[args[1]]))
            pred_aggregated_hid_args[pred].append(args)
        pred_aggregated_hid_list = [[
            pred, pred_aggregated_hid[pred]
        ] for pred in sorted(pred_aggregated_hid.keys())]

        for current_epoch in range(cmd_args.num_epochs):
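            # one variational EM iteration per epoch: the E-step fits the posterior / GCN
            # parameters, the M-step then updates the logic rule weights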

            # E-step: optimize the parameters in the posterior model
            num_batches = int(
                math.ceil(len(dataset.test_fact_ls) / cmd_args.batchsize))

            pbar = tqdm(total=num_batches)
            acc_loss = 0.0
            cur_batch = 0

            for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \
                dataset.get_batch_by_q(cmd_args.batchsize):

                node_embeds = gcn(dataset)  # recompute embeddings every batch so the GCN receives gradients

                loss = 0.0
                r_cnt = 0
                for ind, samples in enumerate(samples_by_r):
                    neg_mask = neg_mask_by_r[ind]
                    latent_mask = latent_mask_by_r[ind]
                    obs_var = obs_var_by_r[ind]
                    neg_var = neg_var_by_r[ind]

                    if sum([len(e[1]) for e in neg_mask]) == 0:
                        continue

                    potential, posterior_prob, obs_xent = posterior_model(
                        [samples, neg_mask, latent_mask, obs_var, neg_var],
                        node_embeds,
                        fast_mode=True)

                    if cmd_args.no_entropy == 1:
                        entropy = 0
                    else:
                        entropy = compute_entropy(
                            posterior_prob) / cmd_args.entropy_temp

                    loss += -(potential.sum() * dataset.rule_ls[ind].weight +
                              entropy) / (potential.size(0) + 1e-6) + obs_xent

                    r_cnt += 1

                if r_cnt > 0:
                    loss /= r_cnt
                    acc_loss += loss.item()

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                pbar.update()
                cur_batch += 1
                pbar.set_description(
                    'Epoch %d, train loss: %.4f, lr: %.4g' %
                    (current_epoch, acc_loss / cur_batch, get_lr(optimizer)))

            # M-step: optimize the weights of logic rules
            with torch.no_grad():
                posterior_prob = posterior_model(pred_aggregated_hid_list,
                                                 node_embeds,
                                                 fast_inference_mode=True)
                for pred_i, (pred,
                             var_ls) in enumerate(pred_aggregated_hid_list):
                    for var_i, var in enumerate(var_ls):
                        args = pred_aggregated_hid_args[pred][var_i]
                        grounded_hid_score[(
                            pred, args)] = posterior_prob[pred_i][var_i]

                rule_weight_gradient = torch.zeros(len(dataset.rule_ls))
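                # accumulate (target - Markov-blanket probability) for every rule grounding each head
                # atom: observed atoms use target 1.0, hidden atoms use their posterior score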
                for (pred, args) in grounded_obs:
                    for rule_idx in set(grounded_obs[(pred, args)]):
                        rule_weight_gradient[
                            rule_idx] += 1.0 - compute_MB_proba(
                                dataset.rule_ls, grounded_obs[(pred, args)])
                for (pred, args) in grounded_hid:
                    for rule_idx in set(grounded_hid[(pred, args)]):
                        target = grounded_hid_score[(pred, args)]
                        rule_weight_gradient[
                            rule_idx] += target - compute_MB_proba(
                                dataset.rule_ls, grounded_hid[(pred, args)])

                for rule_idx, rule in enumerate(dataset.rule_ls):
                    rule.weight += cmd_args.learning_rate_rule_weights * rule_weight_gradient[
                        rule_idx]
                    print(dataset.rule_ls[rule_idx].weight, end=' ')

            pbar.close()

            # validation
            with torch.no_grad():
                node_embeds = gcn(dataset)

                valid_loss = 0.0
                cnt_batch = 0
                for samples_by_r, latent_mask_by_r, neg_mask_by_r, obs_var_by_r, neg_var_by_r in \
                    dataset.get_batch_by_q(cmd_args.batchsize, validation=True):
                    loss = 0.0
                    r_cnt = 0
                    for ind, samples in enumerate(samples_by_r):
                        neg_mask = neg_mask_by_r[ind]
                        latent_mask = latent_mask_by_r[ind]
                        obs_var = obs_var_by_r[ind]
                        neg_var = neg_var_by_r[ind]

                        if sum([len(e[1]) for e in neg_mask]) == 0:
                            continue

                        valid_potential, valid_prob, valid_obs_xent = posterior_model(
                            [samples, neg_mask, latent_mask, obs_var, neg_var],
                            node_embeds,
                            fast_mode=True)

                        if cmd_args.no_entropy == 1:
                            valid_entropy = 0
                        else:
                            valid_entropy = compute_entropy(
                                valid_prob) / cmd_args.entropy_temp

                        loss += -(valid_potential.sum() + valid_entropy) / (
                            valid_potential.size(0) + 1e-6) + valid_obs_xent

                        r_cnt += 1

                    if r_cnt > 0:
                        loss /= r_cnt
                        valid_loss += loss.item()

                    cnt_batch += 1

                tqdm.write('Epoch %d, valid loss: %.4f' %
                           (current_epoch, valid_loss / cnt_batch))

                should_stop = monitor.update(valid_loss)
                # the scheduler was created with mode='max', so negate the loss as its improvement signal
                scheduler.step(-valid_loss)

                is_current_best = monitor.cnt == 0
                if is_current_best:
                    savepath = joinpath(cmd_args.exp_path, 'saved_model')
                    os.makedirs(savepath, exist_ok=True)
                    torch.save(gcn.state_dict(),
                               joinpath(savepath, 'gcn.model'))
                    torch.save(posterior_model.state_dict(),
                               joinpath(savepath, 'posterior.model'))

                should_stop = should_stop or (current_epoch + 1
                                              == cmd_args.num_epochs)

                if should_stop:
                    tqdm.write('Early stopping')
                    break

        # ======================= generate rank list =======================
        node_embeds = gcn(dataset)

        pbar = tqdm(total=len(dataset.test_fact_ls))
        pbar.write('*' * 10 + ' Evaluation ' + '*' * 10)
        rrank = 0.0
        hits = 0.0
        cnt = 0

        rrank_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])
        hits_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])
        cnt_pred = dict([(pred_name, 0.0) for pred_name in PRED_DICT])

        for pred_name, X, invX, sample in gen_eval_query(dataset,
                                                         const2ind=kg.ent2idx):
            x_mat = np.array(X)
            invx_mat = np.array(invX)
            sample_mat = np.array(sample)

            tail_score, head_score, true_score = posterior_model(
                [pred_name, x_mat, invx_mat, sample_mat],
                node_embeds,
                batch_mode=True)

            # raw (unfiltered) rank: 1 + the number of corrupted candidates scoring at least the true triple
            rank = torch.sum(tail_score >= true_score).item() + 1
            rrank += 1.0 / rank
            hits += 1 if rank <= 10 else 0

            rrank_pred[pred_name] += 1.0 / rank
            hits_pred[pred_name] += 1 if rank <= 10 else 0

            rank = torch.sum(head_score >= true_score).item() + 1
            rrank += 1.0 / rank
            hits += 1 if rank <= 10 else 0

            rrank_pred[pred_name] += 1.0 / rank
            hits_pred[pred_name] += 1 if rank <= 10 else 0

            cnt_pred[pred_name] += 2
            cnt += 2

            if cnt % 100 == 0:
                with open(logpath, 'w') as f:
                    f.write('%i sample eval\n' % cnt)
                    f.write('mrr %.4f\n' % (rrank / cnt))
                    f.write('hits %.4f\n' % (hits / cnt))

                    f.write('\n')
                    for pred_name in PRED_DICT:
                        if cnt_pred[pred_name] == 0:
                            continue
                        f.write('mrr %s %.4f\n' %
                                (pred_name,
                                 rrank_pred[pred_name] / cnt_pred[pred_name]))
                        f.write('hits %s %.4f\n' %
                                (pred_name,
                                 hits_pred[pred_name] / cnt_pred[pred_name]))

            pbar.update()

        with open(logpath, 'w') as f:
            f.write('complete\n')
            f.write('mrr %.4f\n' % (rrank / cnt))
            f.write('hits %.4f\n' % (hits / cnt))
            f.write('\n')

            tqdm.write('mrr %.4f\n' % (rrank / cnt))
            tqdm.write('hits %.4f\n' % (hits / cnt))

            for pred_name in PRED_DICT:
                if cnt_pred[pred_name] == 0:
                    continue
                f.write(
                    'mrr %s %.4f\n' %
                    (pred_name, rrank_pred[pred_name] / cnt_pred[pred_name]))
                f.write(
                    'hits %s %.4f\n' %
                    (pred_name, hits_pred[pred_name] / cnt_pred[pred_name]))

        os.system(
            'mv %s %s' %
            (logpath,
             joinpath(
                 cmd_args.exp_path, 'performance_hits_%.4f_mrr_%.4f.txt' %
                 ((hits / cnt), (rrank / cnt)))))
        pbar.close()

    # for Kinship / UW-CSE / Cora data
    elif cmd_args.load_method == 0:
        for current_epoch in range(cmd_args.num_epochs):
            pbar = tqdm(range(cmd_args.num_batches))
            acc_loss = 0.0

            for k in pbar:
                node_embeds = gcn(dataset)

                batch_neg_mask, flat_list, batch_latent_var_inds, observed_rule_cnts, batch_observed_vars = dataset.get_batch_rnd(
                    observed_prob=cmd_args.observed_prob,
                    filter_latent=cmd_args.filter_latent == 1,
                    closed_world=cmd_args.closed_world == 1,
                    filter_observed=1)

                posterior_prob = posterior_model(flat_list, node_embeds)

                if cmd_args.no_entropy == 1:
                    # keep entropy as a tensor so the .to('cpu') call below does not fail
                    entropy = torch.tensor(0.0)
                else:
                    entropy = compute_entropy(
                        posterior_prob) / cmd_args.entropy_temp

                entropy = entropy.to('cpu')
                posterior_prob = posterior_prob.to('cpu')

                potential = mln(batch_neg_mask, batch_latent_var_inds,
                                observed_rule_cnts, posterior_prob, flat_list,
                                batch_observed_vars)

                optimizer.zero_grad()

                loss = -(potential + entropy) / cmd_args.batchsize
                acc_loss += loss.item()

                loss.backward()

                optimizer.step()

                pbar.set_description('train loss: %.4f, lr: %.4g' %
                                     (acc_loss / (k + 1), get_lr(optimizer)))

            # test
            node_embeds = gcn(dataset)
            with torch.no_grad():

                posterior_prob = posterior_model(
                    [(e[1], e[2]) for e in dataset.test_fact_ls], node_embeds)
                posterior_prob = posterior_prob.to('cpu')

                label = np.array([e[0] for e in dataset.test_fact_ls])
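                # |(1 - y) - p| is p when y == 1 and 1 - p when y == 0, i.e. the probability
                # assigned to the observed label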
                test_log_prob = float(
                    np.sum(
                        np.log(
                            np.clip(
                                np.abs((1 - label) - posterior_prob.numpy()),
                                1e-6, 1 - 1e-6))))

                auc_roc = roc_auc_score(label, posterior_prob.numpy())
                auc_pr = average_precision_score(label, posterior_prob.numpy())

                tqdm.write(
                    'Epoch: %d, train loss: %.4f, test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f'
                    % (current_epoch, acc_loss / cmd_args.num_batches, auc_roc,
                       auc_pr, test_log_prob))
                # tqdm.write(str(posterior_prob[:10]))

            # validation for early stop
            valid_sample = []
            valid_label = []
            for pred_name in dataset.valid_dict_2:
                for val, consts in dataset.valid_dict_2[pred_name]:
                    valid_sample.append((pred_name, consts))
                    valid_label.append(val)
            valid_label = np.array(valid_label)

            valid_prob = posterior_model(valid_sample, node_embeds)
            valid_prob = valid_prob.to('cpu')

            valid_log_prob = float(
                np.sum(
                    np.log(
                        np.clip(np.abs((1 - valid_label) - valid_prob.numpy()),
                                1e-6, 1 - 1e-6))))

            # tqdm.write('epoch: %d, valid log prob: %.4f' % (current_epoch, valid_log_prob))
            #
            # should_stop = monitor.update(-valid_log_prob)
            # scheduler.step(valid_log_prob)
            #
            # is_current_best = monitor.cnt == 0
            # if is_current_best:
            #   savepath = joinpath(cmd_args.exp_path, 'saved_model')
            #   os.makedirs(savepath, exist_ok=True)
            #   torch.save(gcn.state_dict(), joinpath(savepath, 'gcn.model'))
            #   torch.save(posterior_model.state_dict(), joinpath(savepath, 'posterior.model'))
            #
            # should_stop = should_stop or (current_epoch + 1 == cmd_args.num_epochs)
            #
            # if should_stop:
            #   tqdm.write('Early stopping')
            #   break

        # evaluation after training
        node_embeds = gcn(dataset)
        with torch.no_grad():
            posterior_prob = posterior_model([(e[1], e[2])
                                              for e in dataset.test_fact_ls],
                                             node_embeds)
            posterior_prob = posterior_prob.to('cpu')

            label = np.array([e[0] for e in dataset.test_fact_ls])
            test_log_prob = float(
                np.sum(
                    np.log(
                        np.clip(np.abs((1 - label) - posterior_prob.numpy()),
                                1e-6, 1 - 1e-6))))

            auc_roc = roc_auc_score(label, posterior_prob.numpy())
            auc_pr = average_precision_score(label, posterior_prob.numpy())

            tqdm.write(
                'test auc-roc: %.4f, test auc-pr: %.4f, test log prob: %.4f' %
                (auc_roc, auc_pr, test_log_prob))
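
The EarlyStopMonitor used in Example #2 is not defined in this snippet; a minimal sketch that is consistent with how it is called above (update() returns True once the tracked loss has stopped improving for patience calls, and cnt == 0 marks the current best) might look like this:

# hypothetical stand-in for the EarlyStopMonitor referenced above
class EarlyStopMonitor:
    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.cnt = 0  # calls since the last improvement; 0 means the latest value is the best so far

    def update(self, value):
        # track the minimum of the monitored quantity (a validation loss in Example #2)
        if value < self.best:
            self.best = value
            self.cnt = 0
        else:
            self.cnt += 1
        return self.cnt >= self.patience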