Exemplo n.º 1
0
def _fit_and_score(model, X_tra, grader, verbose):
    """Train ``model`` on ``X_tra`` and score it with ``grader``.

    When ``verbose`` is truthy the grader is also handed to
    ``model.train`` (presumably so progress can be evaluated during
    training -- confirm against ``ModelWrapper.train``).

    Returns:
        ``(rev, mae)`` as produced by ``grader.eval_revenue(model)`` and
        ``grader.eval_mae(model)``, in that order.
    """
    if verbose:
        model.train(X_tra, grader)
    else:
        model.train(X_tra)

    rev = grader.eval_revenue(model)
    mae = grader.eval_mae(model)
    return rev, mae


def train(args, config, X_tra, X_val, input_dim=None, onehot_dim=None):
    """Build, train and evaluate the model selected by ``args.train_task``.

    Args:
        args: Options object. Reads ``train_task``, ``filter_all``,
            ``use_onehot`` and ``verbose``; additionally ``can_model`` and
            ``train_all`` (task ``'cancel'``) or ``use_pretrain`` and
            ``train_all`` (tasks ``'adr'`` / ``'revenue'``).
        config: Model configuration forwarded to the model constructor.
        X_tra: Training data.
        X_val: Validation data; always used to build the ``Grader``.
        input_dim: Forwarded to ``ModelWrapper`` (unused for ``'cancel'``).
        onehot_dim: Forwarded to ``ModelWrapper`` (unused for ``'cancel'``).

    Returns:
        ``(model, cer)`` for task ``'cancel'``, where ``cer`` is the cancel
        error rate; ``(model, rev, mae)`` for ``'adr'``, ``'revenue'`` and
        ``'label'``. Any other task value falls through and returns
        ``None``.
    """
    task = args.train_task

    if task == 'cancel':
        grader = Grader(X_val)
        model = CancelModel(args.can_model, config, args.filter_all,
                            args.use_onehot)
        if args.train_all:
            # Fold the validation data into the training data.
            X_tra = X_tra + X_val

        model.train(X_tra)

        # cancel error rate a.k.a CER
        cer = grader.eval_cancel_error_rate(model, IsCancelModel=True)
        return model, cer

    if task in ('adr', 'revenue'):
        grader = Grader(X_val)
        model = ModelWrapper(args, config, args.use_onehot, args.filter_all,
                             input_dim, onehot_dim)

        if args.use_pretrain:
            # Same positional order (use_onehot, filter_all) as every other
            # ModelWrapper construction in this function; the two flags
            # were previously swapped here.
            pretrain_model = ModelWrapper(args, config, args.use_onehot,
                                          args.filter_all, input_dim,
                                          onehot_dim)
            pretrain_model.load('trained_models/pretrain.pkl')
            # Start from the pretrained inner model.
            model.model.model = pretrain_model.model.model

        if args.train_all:
            # Fold the validation data into the training data.
            X_tra = X_tra + X_val

        # revenue MAE a.k.a REV
        rev, mae = _fit_and_score(model, X_tra, grader, args.verbose)
        return model, rev, mae

    if task == 'label':
        grader = Grader(X_val)
        model = ModelWrapper(args, config, args.use_onehot, args.filter_all,
                             input_dim, onehot_dim)

        rev, mae = _fit_and_score(model, X_tra, grader, args.verbose)
        return model, rev, mae

    # Unrecognised task: keep the historical behaviour of returning None.
    return None
Exemplo n.º 2
0
# Make sure the output directory exists before any figure is saved there;
# savefig does not create missing parent directories.
plot_dir = os.path.join(opt['model_save_dir'], 'plot')
os.makedirs(plot_dir, exist_ok=True)

# Scatter of the first two columns of `output`, coloured by `bert_sim`,
# with the key item (row opt['key_id']) marked as a red dot.
fig = plt.figure()
cset = plt.scatter(output[:, 0], output[:, 1], s=8, c=bert_sim, cmap='PuBu')
plt.plot(output[opt['key_id'], 0], output[opt['key_id'], 1], 'r.')
plt.colorbar(cset)
plt.savefig(os.path.join(plot_dir, '%d_bert_sim' % opt['key_id']))

# Histogram of the `bert_sim` values.
fig = plt.figure()
plt.hist(bert_sim, bins=100)
plt.savefig(os.path.join(plot_dir, '%d_bert_hist' % opt['key_id']))

model = ModelWrapper(opt, weibo2embid, eva=True)
model.load(os.path.join(opt['model_save_dir'], 'best_model.pt'))

# Delta between every embedding index and the key index, as reported by
# the loaded model. The key-side tensor does not change between
# iterations, so it is built once instead of once per index.
clash_delta = []
emb2 = torch.tensor([opt['key_id']]).cuda()
for i in range(len(emb_matrix)):
    emb1 = torch.tensor([i]).cuda()
    clash_delta.append(model.model.get_delta(emb1, emb2).item())

# Largest absolute delta with a 5% margin.
delta_max = max(clash_delta)
delta_min = min(clash_delta)
ceil = max(delta_max, delta_min * (-1)) * 1.05
print(delta_max, delta_min)
fig = plt.figure()
cset = plt.scatter(output[:, 0],
                   output[:, 1],
                   s=8,
                   c=clash_delta,
                   cmap='PuBu',
Exemplo n.º 3
0
 def __init__(self):
     """Store the object returned by ``ModelWrapper.load()`` as ``self.mw``."""
     wrapper = ModelWrapper.load()
     self.mw = wrapper