Example #1
def go(arg):
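    # -- assumed context for these snippets (not shown in the listing): os, random,
    #    torch, torch.nn.functional as F, tqdm (and its trange), a SummaryWriter
    #    (torch.utils.tensorboard or tensorboardX), and project helpers such as
    #    kgmodels, util, prt, tic/toc, set_lr, corrupt, filter and get_slug.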

    global repeats
    repeats = arg.repeats

    tbdir = arg.tb_dir if arg.tb_dir is not None else os.path.join(
        './runs', get_slug(arg))[:250]
    tbw = SummaryWriter(log_dir=tbdir)

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    test_mrrs = []

    train, val, test, (n2i, i2n), (r2i, i2r) = \
        kgmodels.load_lp(arg.name)

    print(len(i2n), 'nodes')
    print(len(i2r), 'relations')
    print(train.size(0), 'training triples')
    print(test.size(0), 'test triples')
    print(train.size(0) + test.size(0), 'total triples')

    # set of all triples (for filtering)
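    # -- used for "filtered" evaluation: when ranking candidates for a test triple,
    #    corruptions that are themselves true triples are removed, so they can't
    #    push the correct answer down the ranking.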
    alltriples = set()
    for s, p, o in torch.cat([train, test], dim=0):
        s, p, o = s.item(), p.item(), o.item()

        alltriples.add((s, p, o))

    if arg.final:
        train, test = torch.cat([train, val], dim=0), test
    else:
        train, test = train, val

    if arg.decomp == 'block':
        # -- pad the node list to make it divisible by the nr. of blocks

        added = 0
        while len(i2n) % arg.num_blocks != 0:
            label = 'null' + str(added)
            i2n.append(label)
            n2i[label] = len(i2n) - 1

            added += 1

        print(
            f'nodes padded to {len(i2n)} to make it divisible by {arg.num_blocks} (added {added} null nodes).'
        )
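        # -- e.g. with 10 nodes and num_blocks=4, two null nodes ('null0', 'null1')
        #    are appended to bring the count to 12.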

    if repeats > 1:
        RP, EP = trange, range
    else:
        RP, EP = range, trange
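    # -- progress bar placement: with multiple repeats, tqdm wraps the repeats
    #    loop (RP) and the batches run silently; with a single repeat it wraps
    #    the batch loop instead (EP).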

    for r in RP(repeats):
        """
        Define model
        """
        if arg.model == 'classic':
            model = kgmodels.LinkPrediction(triples=train,
                                            n=len(i2n),
                                            r=len(i2r),
                                            hidden=arg.emb,
                                            out=arg.emb,
                                            decomp=arg.decomp,
                                            numbases=arg.num_bases,
                                            numblocks=arg.num_blocks,
                                            depth=arg.depth,
                                            do=arg.do,
                                            biases=arg.biases,
                                            prune=arg.prune,
                                            dropout=arg.edge_dropout)
        elif arg.model == 'narrow':
            model = kgmodels.LPNarrow(triples=train,
                                      n=len(i2n),
                                      r=len(i2r),
                                      emb=arg.emb,
                                      hidden=arg.hidden,
                                      decomp=arg.decomp,
                                      numbases=arg.num_bases,
                                      numblocks=arg.num_blocks,
                                      depth=arg.depth,
                                      do=arg.do,
                                      biases=arg.biases,
                                      prune=arg.prune,
                                      edge_dropout=arg.edge_dropout)
        elif arg.model == 'sampling':
            model = kgmodels.SimpleLP(triples=train,
                                      n=len(i2n),
                                      r=len(i2r),
                                      emb=arg.emb,
                                      h=arg.hidden,
                                      ksample=arg.k,
                                      csample=arg.c,
                                      multi=arg.multi,
                                      decoder=arg.decoder)
        else:
            raise Exception(f'model not recognized: {arg.model}')

        if torch.cuda.is_available():
            prt('Using CUDA.')
            model.cuda()

        if arg.opt == 'adam':
            opt = torch.optim.Adam(model.parameters(), lr=arg.lr[0])
        elif arg.opt == 'adamw':
            opt = torch.optim.AdamW(model.parameters(), lr=arg.lr[0])
        elif arg.opt == 'adagrad':
            opt = torch.optim.Adagrad(model.parameters(), lr=arg.lr[0])
        elif arg.opt == 'sgd':
            opt = torch.optim.SGD(model.parameters(),
                                  lr=arg.lr[0],
                                  nesterov=True,
                                  momentum=arg.momentum)
        else:
            raise Exception(f'optimizer not recognized: {arg.opt}')

        # nr of negatives sampled
        ng = arg.negative_rate

        seen = 0
        for e in range(sum(arg.epochs)):

            depth = 0
            set_lr(opt, arg.lr[0])
            if e >= arg.epochs[0]:
                depth = 1
                set_lr(opt, arg.lr[1])
            if e >= sum(arg.epochs[:2]):
                depth = 2
                set_lr(opt, arg.lr[2])

            seeni, sumloss = 0, 0.0

            if arg.c is not None:
                tic()
                model.precompute_globals()
                print(f'precomp took {toc():.2}s')

            tsample, tforward, tbackward, ttotal, tloss, tstep = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
            for fr in EP(0, train.size(0), arg.batch):

                tic()
                model.train(True)

                if arg.limit is not None and seeni > arg.limit:
                    break

                to = min(train.size(0), fr + arg.batch)

                with torch.no_grad():
                    positives = train[fr:to]

                    b, _ = positives.size()

                    tic()

                    # sample negatives
                    if arg.corrupt_global:  # global corruption (sample random true triples to corrupt)
                        indices = torch.randint(size=(b * ng, ),
                                                low=0,
                                                high=train.size(0))
                        negatives = train[indices, :].view(
                            b, ng, 3)  # -- triples to be corrupted

                    else:  # local corruption (directly corrupt the current batch)
                        negatives = positives.clone()[:, None, :].expand(
                            b, ng, 3).contiguous()

                    corrupt(negatives, len(i2n))
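                    # -- corrupt() is a project helper, assumed to replace the
                    #    subject or object of each negative, in place, with a
                    #    random node index (a minimal sketch follows this example).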

                    triples = torch.cat([positives[:, None, :], negatives],
                                        dim=1)

                    if torch.cuda.is_available():
                        triples = triples.cuda()

                    if arg.loss == 'bce':
                        labels = torch.cat(
                            [torch.ones(b, 1),
                             torch.zeros(b, ng)], dim=1)
                    elif arg.loss == 'ce':
                        labels = torch.zeros(b, dtype=torch.long)
                        # -- CE loss treats the problem as a multiclass classification problem: for a positive triple,
                        #    together with its k corruptions, identify which is the true triple. This is always triple 0,
                        #    but the score function is order equivariant, so the model can't see the index of the
                        #    triple it's classifying (it can't cheat by always picking 0).
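                    # -- e.g. for b=2, ng=3: BCE labels are [[1,0,0,0], [1,0,0,0]];
                    #    CE labels are [0, 0], the index of the true triple.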

                    if torch.cuda.is_available():
                        labels = labels.cuda()

                tsample += toc()

                opt.zero_grad()

                tic()
                out = model(triples, depth=depth)

                assert out.size() == (b, ng + 1)

                tic()
                if arg.loss == 'bce':
                    loss = F.binary_cross_entropy_with_logits(out, labels)
                elif arg.loss == 'ce':
                    loss = F.cross_entropy(out, labels)

                if arg.l2weight is not None:
                    l2 = sum([p.pow(2).sum() for p in model.parameters()])
                    loss = loss + arg.l2weight * l2

                tloss += toc()

                tforward += toc()

                tic()
                loss.backward()
                tbackward += toc()

                sumloss += float(loss.item())

                tic()
                opt.step()
                tstep += toc()

                seen += b
                seeni += b
                ttotal += toc()

            prt(f'epoch {e} (d{depth}); training loss {sumloss/seeni:.4}       s {tsample:.3}s, f {tforward:.3}s (loss {tloss:.3}s), b {tbackward:.3}, st {tstep:.3}, t {ttotal:.3}s'
                )

            # Evaluate
            if (e % arg.eval_int == 0 and e != 0) or e == sum(arg.epochs) - 1:
                with torch.no_grad():

                    model.train(False)

                    ranks = []

                    mrr = hitsat1 = hitsat3 = hitsat10 = 0.0

                    if arg.eval_size is None:
                        testsub = test
                    else:
                        testsub = test[random.sample(range(test.size(0)),
                                                     k=arg.eval_size)]

                    tseen = 0
                    for tail in [True, False]:  # head or tail prediction

                        for s, p, o in (testsub if repeats > 1 else
                                        tqdm.tqdm(testsub)):

                            s, p, o = s.item(), p.item(), o.item()

                            if tail:
                                ot = o
                                del o

                                raw_candidates = [(s, p, o)
                                                  for o in range(len(i2n))]
                                candidates = filter(raw_candidates, alltriples,
                                                    (s, p, ot))

                            else:
                                st = s
                                del s

                                raw_candidates = [(s, p, o)
                                                  for s in range(len(i2n))]
                                candidates = filter(raw_candidates, alltriples,
                                                    (st, p, o))
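                            # -- filter() is a project helper (not the builtin),
                            #    assumed to drop every raw candidate that occurs in
                            #    alltriples except the target triple itself.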

                            candidates = torch.tensor(candidates)
                            scores = util.batch(model,
                                                candidates,
                                                batch_size=arg.batch * 2,
                                                depth=depth)
                            # -- the batch size needs to be a little conservative here, due to the high variance in nr of
                            #    triples sampled.

                            sorted_candidates = [
                                tuple(pair[0])
                                for pair in sorted(zip(candidates.tolist(),
                                                       scores.tolist()),
                                                   key=lambda pair: -pair[1])
                            ]

                            rank = (sorted_candidates.index((s, p, ot)) +
                                    1) if tail else (sorted_candidates.index(
                                        (st, p, o)) + 1)
                            ranks.append(rank)

                            hitsat1 += (rank == 1)
                            hitsat3 += (rank <= 3)
                            hitsat10 += (rank <= 10)
                            mrr += 1.0 / rank
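                            # -- MRR averages the reciprocal ranks; hits@k is the
                            #    fraction of predictions with rank <= k.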

                            tseen += 1

                    mrr = mrr / tseen
                    hitsat1 = hitsat1 / tseen
                    hitsat3 = hitsat3 / tseen
                    hitsat10 = hitsat10 / tseen

                    prt(f'epoch {e}: MRR {mrr:.4}\t hits@1 {hitsat1:.4}\t  hits@3 {hitsat3:.4}\t  hits@10 {hitsat10:.4}'
                        )
                    prt(f'   ranks : {ranks[:10]}')

        test_mrrs.append(mrr)

    print('training finished.')

    temrrs = torch.tensor(test_mrrs)
    print(
        f'mean test MRR    {temrrs.mean():.3} ({temrrs.std():.3})  \t{test_mrrs}'
    )
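
The corrupt() helper above is not part of this listing. A minimal sketch of the assumed behavior (the original implementation may differ): replace the subject or object of each negative triple, in place, with a uniformly random node.

def corrupt(negatives, n):
    """Sketch: in-place corruption of a (b, ng, 3) batch of negatives.
    Each triple gets either its subject (column 0) or its object
    (column 2) replaced by a random node index in [0, n)."""
    b, ng, _ = negatives.size()
    rows = torch.arange(b)[:, None].expand(b, ng)
    cols = torch.arange(ng)[None, :].expand(b, ng)
    which = torch.randint(2, size=(b, ng)) * 2  # 0 = subject, 2 = object
    negatives[rows, cols, which] = torch.randint(n, size=(b, ng))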
Example #2
def extract_headline(doc, url):

    logging.debug("extracting headline")

    candidates = {}

    for h in util.tags(doc, 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'div', 'span', 'td', 'th'):
        score = 1
        txt = unicode(h.text_content()).strip()
        txt = u' '.join(txt.split())
        if len(txt) == 0:
            continue

        txt_norm = util.normalise_text(txt)

        if len(txt) >= 500:
            continue

        logging.debug(" headline: consider %s '%s'" % (h.tag, txt))

        # TODO: should run all these tests over a real corpus of articles
        # and work out proper probability-based scoring!

        # TEST: length of headline
        # TODO - get a proper headline-length frequency curve from
        # journalisted and score according to probability
        # (narrower range first, so it isn't shadowed by the broader one)
        if len(txt) >= 25 and len(txt) < 40:
            logging.debug("  len in [25,40)")
            score += 2
        elif len(txt) >= 20 and len(txt) < 60:
            logging.debug("  len in [20,60)")
            score += 1

        if h.tag in ('h1','h2','h3','h4'):
            logging.debug("  significant heading (%s)" % (h.tag,))
            score += 2
        if h.tag in ('span','td'):
            logging.debug("  -2 less headline-y element (%s)" % (h.tag,))
            score -= 2

        # TEST: does it appear in <title> text?
        title = doc.find('.//title')
        if title is not None and title.text:
            if txt_norm in util.normalise_text(unicode(title.text)):
                logging.debug("  appears in <title>")
                score += 3

        # TEST: likely-looking class or id
        if pats.headline['classes'].search(h.get('class','')):
            logging.debug("  likely class")
            score += 2
        if pats.headline['classes'].search(h.get('id','')):
            logging.debug("  likely id")
            score += 2


        # TEST: does it appear in likely looking <meta> tags?
        # eg:
        # <meta property="og:title" content="Dementia checks at age 75 urged"/>
        # <meta name="Headline" content="Dementia checks at age 75 urged"/>
        for meta in doc.findall('.//meta'):
            n = meta.get('name', meta.get('property', ''))
            if pats.headline['metatags'].search(n):
                meta_content = util.normalise_text(unicode(meta.get('content','')))
                if meta_content != '':
                    if txt_norm==meta_content:
                        logging.debug("  match meta")
                        score += 3
                    elif txt_norm in meta_content:
                        logging.debug("  contained by meta")
                        score += 1

        # TEST: does it match slug part of url?
        slug = re.split('[-_]', util.get_slug(url).lower())
        parts = [util.normalise_text(part) for part in txt.split()]
        parts = [part for part in parts if part!='']
        if len(parts) > 1:
            matched = [part for part in parts if part in slug]

            value = (5.0*len(matched) / len(parts)) # max 5 points
            if value > 0:
                logging.debug("  match slug (%01f)" % (value,))
                score += value

        # TODO: other possible tests
        # TEST: is it near top of article container?
        # TEST: is it just above article container?
        # TEST: is it non-complex html (anything more complex than <a>)
        # TEST: is it outside likely sidebar elements?

        if txt not in candidates or score > candidates[txt]['score']:
            candidates[txt] = {'txt': txt, 'score': score, 'sourceline': h.sourceline, 'node': h}

    if not candidates:
        return None

    # sort
    out = sorted(candidates.items(), key=lambda item: item[1]['score'], reverse=True)

    return out[0][1]
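
A minimal usage sketch, assuming the page is parsed with lxml.html (which provides the text_content(), find(), findall(), get() and sourceline APIs used above); util and pats are this project's helper modules:

import lxml.html

doc = lxml.html.fromstring(open('article.html').read())
best = extract_headline(doc, 'http://example.com/news/dementia-checks-at-75-urged')
if best is not None:
    print(best['txt'])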
Example #3
def go(arg):
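    # -- assumed context as in Example #1, plus the project's embed module and
    #    a device helper d() returning the current device.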

    global repeats
    repeats = arg.repeats

    tbdir = arg.tb_dir if arg.tb_dir is not None else os.path.join('./runs', get_slug(arg))[:250]
    tbw = SummaryWriter(log_dir=tbdir)

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    test_mrrs = []

    train, val, test, (n2i, i2n), (r2i, i2r) = \
        embed.load(arg.name)

    # set of all triples (for filtering)
    alltriples = set()
    for s, p, o in torch.cat([train, val, test], dim=0):
        s, p, o = s.item(), p.item(), o.item()

        alltriples.add((s, p, o))

    truedicts = util.truedicts(alltriples)
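    # -- truedicts is assumed to index alltriples by (s, p) and (p, o) for fast
    #    filtered ranking (a sketch of the assumed structure follows this example).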

    if arg.final:
        train, test = torch.cat([train, val], dim=0), test
    else:
        train, test = train, val

    subjects   = torch.tensor(list({s for s, _, _ in train}), dtype=torch.long, device=d())
    predicates = torch.tensor(list({p for _, p, _ in train}), dtype=torch.long, device=d())
    objects    = torch.tensor(list({o for _, _, o in train}), dtype=torch.long, device=d())
    ccandidates = (subjects, predicates, objects)

    print(len(i2n), 'nodes')
    print(len(i2r), 'relations')
    print(train.size(0), 'training triples')
    print(test.size(0), 'test triples')
    print(train.size(0) + test.size(0), 'total triples')

    for r in tqdm.trange(repeats) if repeats > 1 else range(repeats):

        """
        Define model
        """
        model = embed.LinkPredictor(
            triples=train, n=len(i2n), r=len(i2r), embedding=arg.emb, biases=arg.biases,
            edropout = arg.edo, rdropout=arg.rdo, decoder=arg.decoder, reciprocal=arg.reciprocal,
            init_method=arg.init_method, init_parms=arg.init_parms)

        if torch.cuda.is_available():
            prt('Using CUDA.')
            model.cuda()

        if arg.opt == 'adam':
            opt = torch.optim.Adam(model.parameters(), lr=arg.lr)
        elif arg.opt == 'adamw':
            opt = torch.optim.AdamW(model.parameters(), lr=arg.lr)
        elif arg.opt == 'adagrad':
            opt = torch.optim.Adagrad(model.parameters(), lr=arg.lr)
        elif arg.opt == 'sgd':
            opt = torch.optim.SGD(model.parameters(), lr=arg.lr, nesterov=True, momentum=arg.momentum)
        else:
            raise Exception(f'optimizer not recognized: {arg.opt}')

        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=arg.patience, optimizer=opt, mode='max', factor=0.95, threshold=0.0001) \
            if arg.sched else None
        #-- defaults taken from libkge

        # optional per-column weights for the BCE loss below
        weight = torch.tensor([arg.nweight, 1.0], device=d()) if arg.nweight else None

        seen = 0
        for e in range(arg.epochs):

            seeni, sumloss = 0, 0.0
            tforward = tbackward = 0
            rforward = rbackward = 0
            tprep = tloss = 0
            tic()

            for fr in trange(0, train.size(0), arg.batch):
                to = min(train.size(0), fr + arg.batch)

                model.train(True)

                opt.zero_grad()

                positives = train[fr:to].to(d())

                for ctarget in [0, 1, 2]: # which part of the triple to corrupt
                    ng = arg.negative_rate[ctarget]

                    if ng > 0:

                        with torch.no_grad():
                            bs, _ = positives.size()

                            tic()
                            if arg.limit_negatives:
                                cand = ccandidates[ctarget]
                                mx = cand.size(0)
                                idx = torch.empty(bs, ng, dtype=torch.long, device=d()).random_(0, mx)
                                corruptions = cand[idx]
                            else:
                                mx = len(i2r) if ctarget == 1 else len(i2n)
                                corruptions = torch.empty(bs, ng, dtype=torch.long, device=d()).random_(0, mx)
                            tprep += toc()

                            s, p, o = positives[:, 0:1], positives[:, 1:2], positives[:, 2:3]
                            if ctarget == 0:
                                s = torch.cat([s, corruptions], dim=1)
                            if ctarget == 1:
                                p = torch.cat([p, corruptions], dim=1)
                            if ctarget == 2:
                                o = torch.cat([o, corruptions], dim=1)

                            # -- NB: two of the index vectors s, p o are now size (bs, 1) and the other is (bs, ng+1)
                            #    We will let the model broadcast these to give us a score tensor of (bs, ng+1)
                            #    In most cases we can optimize the decoder to broadcast late for better speed.

                            if arg.loss == 'bce':
                                labels = torch.cat([torch.ones(bs, 1, device=d()), torch.zeros(bs, ng, device=d())], dim=1)
                            elif arg.loss == 'ce':
                                labels = torch.zeros(bs, dtype=torch.long, device=d())
                                # -- CE loss treats the problem as a multiclass classification problem: for a positive triple,
                                #    together with its k corruptions, identify which is the true triple. This is always triple 0.
                                #    (It may seem like the model could easily cheat by always choosing triple 0, but the score
                                #    function is order equivariant, so it can't choose by ordering.)

                        recip = None if not arg.reciprocal else ('head' if ctarget == 0 else 'tail')
                        # -- We use the tail relations if the target is the relation (usually p-corruption is not used)

                        tic()
                        out = model(s, p, o, recip=recip)
                        tforward += toc()

                        assert out.size() == (bs, ng + 1), f'{out.size()=} {(bs, ng + 1)=}'

                        tic()
                        if arg.loss == 'bce':
                            loss = F.binary_cross_entropy_with_logits(out, labels, weight=weight, reduction=arg.lred)
                        elif arg.loss == 'ce':
                            loss = F.cross_entropy(out, labels, reduction=arg.lred)

                        assert not torch.isnan(loss), 'Loss has become NaN'

                        sumloss += float(loss.item())
                        seen += bs; seeni += bs
                        tloss += toc()

                        tic()
                        loss.backward()
                        tbackward += toc()
                        # No step yet, we accumulate the gradients over all corruptions.
                        # -- this causes problems with modules like batchnorm, so be careful when porting.
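                        # -- gradients from successive backward() calls add up, so
                        #    the single opt.step() below is equivalent to stepping
                        #    on the summed loss over all corruption targets.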

                tic()
                regloss = None
                if arg.reg_eweight is not None:
                    regloss = model.penalty(which='entities', p=arg.reg_exp, rweight=arg.reg_eweight)

                if arg.reg_rweight is not None:
                    rpen = model.penalty(which='relations', p=arg.reg_exp, rweight=arg.reg_rweight)
                    # accumulate rather than overwrite, so both penalties apply
                    regloss = rpen if regloss is None else regloss + rpen
                rforward += toc()

                tic()
                if regloss is not None:
                    sumloss += float(regloss.item())
                    regloss.backward()
                rbackward += toc()

                opt.step()

                tbw.add_scalar('biases/train_loss', float(loss.item()), seen)

            if e == 0:
                print(f'\n pred: forward {tforward:.4}, backward {tbackward:.4}')
                print(f'   reg: forward {rforward:.4}, backward {rbackward:.4}')
                print(f'           prep {tprep:.4}, loss {tloss:.4}')
                print(f' total: {toc():.4}')
                # -- NB: these numbers will not be accurate for GPU runs unless CUDA_LAUNCH_BLOCKING is set to 1

            # Evaluate
            if ((e+1) % arg.eval_int == 0) or e == arg.epochs - 1:

                with torch.no_grad():

                    model.train(False)

                    if arg.eval_size is None:
                        testsub = test
                    else:
                        testsub = test[random.sample(range(test.size(0)), k=arg.eval_size)]

                    mrr, hits, ranks = util.eval(
                        model=model, valset=testsub, truedicts=truedicts, n=len(i2n),
                        batch_size=arg.test_batch, verbose=True)

                    if arg.check_simple: # double-check using a separate, slower implementation
                        mrrs, hitss, rankss = util.eval_simple(
                            model=model, valset=testsub, alltriples=alltriples, n=len(i2n), verbose=True)

                        assert ranks == rankss
                        assert mrr == mrrs

                    print(f'epoch {e}: MRR {mrr:.4}\t hits@1 {hits[0]:.4}\t  hits@3 {hits[1]:.4}\t  hits@10 {hits[2]:.4}')

                    tbw.add_scalar('biases/mrr', mrr, e)
                    tbw.add_scalar('biases/h@1', hits[0], e)
                    tbw.add_scalar('biases/h@3', hits[1], e)
                    tbw.add_scalar('biases/h@10', hits[2], e)

                    if sched is not None:
                        sched.step(mrr) # reduce lr if mrr stalls

        test_mrrs.append(mrr)

    print('training finished.')

    temrrs = torch.tensor(test_mrrs)
    print(f'mean test MRR    {temrrs.mean():.3} ({temrrs.std():.3})  \t{test_mrrs}')
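
util.truedicts and util.eval are not shown in this listing. A sketch of the assumed structure of truedicts, indexing the true triples for filtered ranking (the original implementation may differ):

from collections import defaultdict

def truedicts(alltriples):
    """Sketch: heads[(p, o)] is the set of true subjects for (p, o);
    tails[(s, p)] is the set of true objects for (s, p)."""
    heads, tails = defaultdict(set), defaultdict(set)
    for s, p, o in alltriples:
        heads[(p, o)].add(s)
        tails[(s, p)].add(o)
    return heads, tails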
Example #5
def go(arg):
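    # -- assumed context as in Example #1, plus corrupt_one (sketched at the end
    #    of this example) and the device helper d().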

    global repeats
    repeats = arg.repeats

    tbdir = arg.tb_dir if arg.tb_dir is not None else os.path.join('./runs', get_slug(arg))[:250]
    tbw = SummaryWriter(log_dir=tbdir)

    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    test_mrrs = []

    train, val, test, (n2i, i2n), (r2i, i2r) = \
        kgmodels.load_lp(arg.name)

    # set of all triples (for filtering)
    alltriples = set()
    for s, p, o in torch.cat([train, val, test], dim=0):
        s, p, o = s.item(), p.item(), o.item()

        alltriples.add((s, p, o))

    truedicts = util.truedicts(alltriples)

    if arg.final:
        train, test = torch.cat([train, val], dim=0), test
    else:
        train, test = train, val

    subjects   = list({s for s, _, _ in train})
    predicates = list({p for _, p, _ in train})
    objects    = list({o for _, _, o in train})
    ccandidates = (subjects, predicates, objects)

    print(len(i2n), 'nodes')
    print(len(i2r), 'relations')
    print(train.size(0), 'training triples')
    print(test.size(0), 'test triples')
    print(train.size(0) + test.size(0), 'total triples')

    for r in tqdm.trange(repeats) if repeats > 1 else range(repeats):

        """
        Define model
        """
        model = kgmodels.LPShallow(
            triples=train, n=len(i2n), r=len(i2r), embedding=arg.emb, biases=arg.biases,
            edropout = arg.edo, rdropout=arg.rdo, decoder=arg.decoder)

        if torch.cuda.is_available():
            prt('Using CUDA.')
            model.cuda()

        if arg.opt == 'adam':
            opt = torch.optim.Adam(model.parameters(), lr=arg.lr)
        elif arg.opt == 'adamw':
            opt = torch.optim.AdamW(model.parameters(), lr=arg.lr)
        elif arg.opt == 'adagrad':
            opt = torch.optim.Adagrad(model.parameters(), lr=arg.lr)
        elif arg.opt == 'sgd':
            opt = torch.optim.SGD(model.parameters(), lr=arg.lr, nesterov=True, momentum=arg.momentum)
        else:
            raise Exception(f'optimizer not recognized: {arg.opt}')

        sched = torch.optim.lr_scheduler.ReduceLROnPlateau(patience=arg.patience, optimizer=opt, mode='max', factor=0.95, threshold=0.0001) \
            if arg.sched else None
        #-- defaults taken from libkge

        # optional per-column weights for the BCE loss below
        weight = torch.tensor([arg.nweight, 1.0], device=d()) if arg.nweight else None

        seen = 0
        for e in range(arg.epochs):

            seeni, sumloss = 0, 0.0

            for fr in trange(0, train.size(0), arg.batch):

                tic()
                model.train(True)

                to = min(train.size(0), fr + arg.batch)

                with torch.no_grad():
                    positives = train[fr:to]

                    b, _ = positives.size()

                    ttriples = []
                    for target, ng in zip([0, 1, 2], arg.negative_rate):
                        if ng > 0:

                            negatives = positives.clone()[:, None, :].expand(b, ng, 3).contiguous()
                            # sample replacements from relations when corrupting the predicate, from nodes otherwise
                            mx = len(i2r) if target == 1 else len(i2n)
                            corrupt_one(negatives, ccandidates[target] if arg.limit_negatives else range(mx), target)

                            ttriples.append(torch.cat([positives[:, None, :], negatives], dim=1))

                    triples = torch.cat(ttriples, dim=0)
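                    # -- NB: this concatenation only works if every corrupted target
                    #    uses the same negative rate ng (dim 1 must be ng + 1 for
                    #    all parts); the labels below rely on the same assumption.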

                    b, _, _ = triples.size()

                    if arg.loss == 'bce':
                        labels = torch.cat([torch.ones(b, 1), torch.zeros(b, ng)], dim=1)
                    elif arg.loss == 'ce':
                        labels = torch.zeros(b, dtype=torch.long)
                        # -- CE loss treats the problem as a multiclass classification problem: for a positive triple,
                        #    together with its k corruptions, identify which is the true triple. This is always triple 0.
                        #    (It may seem like the model could easily cheat by always choosing triple 0, but the score
                        #    function is order equivariant, so it can't choose by ordering.)

                    if torch.cuda.is_available():
                        triples = triples.cuda()
                        labels = labels.cuda()

                opt.zero_grad()

                out = model(triples)

                if arg.loss == 'bce':
                    loss = F.binary_cross_entropy_with_logits(out, labels, weight=weight, reduction=arg.lred)
                elif arg.loss == 'ce':
                    loss = F.cross_entropy(out, labels, reduction=arg.lred)

                if arg.reg_eweight is not None:
                    loss = loss + model.penalty(which='entities', p=arg.reg_exp, rweight=arg.reg_eweight)

                if arg.reg_rweight is not None:
                    loss = loss + model.penalty(which='relations', p=arg.reg_exp, rweight=arg.reg_rweight)

                assert not torch.isnan(loss), 'Loss has become NaN'

                loss.backward()

                sumloss += float(loss.item())

                opt.step()

                seen += b; seeni += b
                tbw.add_scalar('biases/train_loss', float(loss.item()), seen)

            # Evaluate
            if ((e+1) % arg.eval_int == 0) or e == arg.epochs - 1:

                with torch.no_grad():

                    model.train(False)

                    if arg.eval_size is None:
                        testsub = test
                    else:
                        testsub = test[random.sample(range(test.size(0)), k=arg.eval_size)]

                    mrr, hits, ranks = util.eval_batch(
                        model=model, valset=testsub, truedicts=truedicts, n=len(i2n),
                        batch_size=arg.test_batch, verbose=True)

                    if arg.check_simple:
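                        # -- double-check using a separate, slower implementation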
                        mrrs, hitss, rankss = util.eval_simple(
                            model=model, valset=testsub, alltriples=alltriples, n=len(i2n), verbose=True)

                        assert ranks == rankss
                        assert mrr == mrrs

                    print(f'epoch {e}: MRR {mrr:.4}\t hits@1 {hits[0]:.4}\t  hits@3 {hits[1]:.4}\t  hits@10 {hits[2]:.4}')
                    print(f'   ranks : {ranks[:10]}')

                    print('len check', len(ranks), len(testsub))

                    tbw.add_scalar('biases/mrr', mrr, e)
                    tbw.add_scalar('biases/h@1', hits[0], e)
                    tbw.add_scalar('biases/h@3', hits[1], e)
                    tbw.add_scalar('biases/h@10', hits[2], e)

                    if sched is not None:
                        sched.step(mrr) # reduce lr if mrr stalls

        test_mrrs.append(mrr)

    print('training finished.')

    temrrs = torch.tensor(test_mrrs)
    print(f'mean test MRR    {temrrs.mean():.3} ({temrrs.std():.3})  \t{test_mrrs}')
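
The corrupt_one() helper used above is likewise defined elsewhere in the codebase. A minimal sketch of the assumed behavior: replace one fixed column of every negative triple with a random draw from the given candidates.

def corrupt_one(negatives, candidates, target):
    """Sketch: in-place replacement of column `target` (0 = subject,
    1 = predicate, 2 = object) in a (b, ng, 3) batch of negatives."""
    b, ng, _ = negatives.size()
    cand = torch.tensor(list(candidates), dtype=torch.long)
    idx = torch.randint(cand.size(0), size=(b, ng))
    negatives[:, :, target] = cand[idx]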