Example #1
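# Assumed imports for this snippet (not shown in the original); load_adjacency_matrix
# comes from the hype package of the poincare-embeddings codebase, while calc_rank_map
# and save_snapshot are presumed to be helpers defined elsewhere in the same file.
import os

import numpy as np
import torch

from hype.graph import load_adjacency_matrix
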
def get_model_info(fname):
    chkpnt = torch.load(fname)
    dset = chkpnt['conf']['dset']
    if not os.path.exists(dset):
        raise ValueError("Can't find dset!")
    data_format = 'hdf5' if dset.endswith('.h5') else 'csv'
    dset = load_adjacency_matrix(dset, data_format)

    info_dict = {}
    info_dict['idmap'] = dset['idmap']
    info_dict['embeddings'] = chkpnt['embeddings'].numpy()
    return info_dict


def main(embdir='./emb150/',
         iterations=300,
         dataset='./wordnet/noun_closure.csv'):
    # set default tensor type
    torch.set_default_tensor_type('torch.DoubleTensor')
    # set device
    device = torch.device('cuda:0')

    data_format = 'hdf5' if dataset.endswith('.h5') else 'csv'
    dset = load_adjacency_matrix(dataset, format=data_format)

    sample_size = len(dset['ids'])
    sample = np.random.choice(len(dset['ids']),
                              size=sample_size,
                              replace=False)

    adj = {}
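    # dset stores the graph in CSR-like form: offsets[i] marks where node i's
    # neighbor list begins in the flat neighbors array, so the slice
    # offsets[i]:offsets[i+1] (or to the end, for the last node) is i's neighbor set.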
    print('calc adj...')
    for i in sample:
        end = dset['offsets'][i + 1] if i + 1 < len(dset['offsets']) \
            else len(dset['neighbors'])
        adj[i] = set(dset['neighbors'][dset['offsets'][i]:end])
    print('calc done...')

    manifolds = [
        'euclidean',
        # 'transe',
        'poincare',
        'lorentz'
    ]
    dimensions = [5, 10, 20, 50, 100, 200]

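    # Sweep every saved checkpoint (emb_<manifold>_<dim>.bin.<epoch>, one written
    # every 10 epochs) and recompute mean rank / mAP for each.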
    for mani in manifolds:
        for dim in dimensions:
            key_json = '%s%d' % (mani, dim)
            for cp in range(9, iterations, 10):
                infile = 'emb_%s_%d.bin.%d' % (mani, dim, cp)
                fpath = os.path.join(embdir, infile)
                print('load %s...' % fpath)
                if os.path.exists(fpath):
                    chkpnt, manifold, lt, meanrank, maprank, elapsed = calc_rank_map(
                        fpath=fpath, adj=adj)
                    ret = save_snapshot(chkpnt=chkpnt,
                                        fpath=fpath,
                                        manifold=manifold,
                                        lt=lt,
                                        elapsed=elapsed,
                                        loss=0,
                                        meanrank=meanrank,
                                        maprank=maprank)
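
# a minimal entry point, assuming the defaults above:
if __name__ == '__main__':
    main()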
Example #3
def _load(opt, log, manifold):
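    """Build (model, data) for training: uses the edge-list dataloader when
    opt.dset looks like a CSV edge list, otherwise the HDF5 adjacency-matrix
    dataloader."""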
    if "csv" in opt.dset:
        log.info("Using edge list dataloader")
        idx, objects, weights = load_edge_list(opt.dset, opt.sym)
        model, data, model_name = initialize(
            manifold,
            idx,
            objects,
            weights,
            sparse=opt.sparse,
            dim=opt.dim,
            negs=opt.negs,
            batch_size=opt.batchsize,
            burnin=opt.burnin,
            dampening=opt.dampening,
        )
    else:
        log.info("Using adjacency matrix dataloader")
        dset = load_adjacency_matrix(opt.dset, "hdf5")
        # noinspection PyArgumentList
        data = AdjacencyDataset(
            dset,
            opt.negs,
            opt.batchsize,
            burnin=opt.burnin > 0,
            sample_dampening=opt.dampening,
        )
        model = Embedding(data.N, opt.dim, manifold, sparse=opt.sparse)

    data.neg_multiplier = opt.neg_multiplier
    log.info(f"conf: {json.dumps(vars(opt))}")
    if opt.checkpoint and not opt.fresh:
        log.info("using loaded checkpoint")
        try:
            model.load_weights(opt.checkpoint)
        except Exception:
            log.info(f"not loading existing weights for {opt.checkpoint}")
    else:
        log.info("starting with fresh model")
    return model, data
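
# hypothetical usage, assuming opt parsed by an argparse setup like those below:
#   manifold = MANIFOLDS[opt.manifold](debug=opt.debug, max_norm=opt.maxnorm)
#   model, data = _load(opt, log, manifold)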
Example #4
def main():
    parser = argparse.ArgumentParser(description='Train Hyperbolic Embeddings')
    parser.add_argument('-dset', type=str, required=True,
                        help='Dataset identifier')
    parser.add_argument('-dim', type=int, default=20,
                        help='Embedding dimension')
    parser.add_argument('-com_n', type=int, default=2,
                        help='Embedding components number')
    parser.add_argument('-manifold', type=str, default='lorentz',
                        choices=MANIFOLDS.keys(), help='Embedding manifold')
    parser.add_argument('-lr', type=float, default=1000,
                        help='Learning rate')
    parser.add_argument('-epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('-batchsize', type=int, default=12800,
                        help='Batchsize')
    parser.add_argument('-negs', type=int, default=50,
                        help='Number of negatives')
    parser.add_argument('-burnin', type=int, default=20,
                        help='Epochs of burn in')
    parser.add_argument('-dampening', type=float, default=0.75,
                        help='Sample dampening during burnin')
    parser.add_argument('-ndproc', type=int, default=8,
                        help='Number of data loading processes')
    parser.add_argument('-eval_each', type=int, default=1,
                        help='Run evaluation every n-th epoch')
    parser.add_argument('-debug', action='store_true', default=False,
                        help='Print debugging output')
    parser.add_argument('-gpu', default=-1, type=int,
                        help='Which GPU to run on (-1 for no gpu)')
    parser.add_argument('-sym', action='store_true', default=False,
                        help='Symmetrize dataset')
    parser.add_argument('-maxnorm', '-no-maxnorm', default='500000',
                        action=Unsettable, type=int)
    parser.add_argument('-sparse', default=False, action='store_true',
                        help='Use sparse gradients for embedding table')
    parser.add_argument('-burnin_multiplier', default=0.01, type=float)
    parser.add_argument('-neg_multiplier', default=1.0, type=float)
    parser.add_argument('-quiet', action='store_true', default=True)
    parser.add_argument('-lr_type', choices=['scale', 'constant'], default='constant')
    parser.add_argument('-train_threads', type=int, default=1,
                        help='Number of threads to use in training')
    parser.add_argument('-eval_embedding', default=None, type=str,
                        help='Path to an embedding checkpoint to evaluate instead of training')
    opt = parser.parse_args()
    
    if 'LTiling' in opt.manifold:
        opt.nor = 'LTiling'
        opt.norevery = 20
        opt.stre = 50
    elif 'HTiling' in opt.manifold:
        opt.nor = 'HTiling'
        opt.norevery = 1
        opt.stre = 0
    else:
        opt.nor = 'none'

    # setup debugging and logging
    log_level = logging.DEBUG if opt.debug else logging.INFO
    log = logging.getLogger('tiling model')
    logging.basicConfig(level=log_level, format='%(message)s', stream=sys.stdout)

    # set default tensor type
    th.set_default_tensor_type('torch.DoubleTensor')  # FloatTensor or DoubleTensor
    # set device
    # device = th.device(f'cuda:{opt.gpu}' if opt.gpu >= 0 else 'cpu')
    device = th.device('cpu')

    # select manifold to optimize on
    manifold = MANIFOLDS[opt.manifold](debug=opt.debug, max_norm=opt.maxnorm, com_n=opt.com_n)
    if 'Halfspace' not in opt.manifold:
        opt.dim = manifold.dim(opt.dim)

    if 'csv' in opt.dset:
        log.info('Using edge list dataloader')
        idx, objects, weights = load_edge_list(opt.dset, opt.sym)
        model, data, model_name, conf = initialize(
            manifold, opt, idx, objects, weights, sparse=opt.sparse
        )
    else:
        log.info('Using adjacency matrix dataloader')
        dset = load_adjacency_matrix(opt.dset, 'hdf5')
        log.info('Setting up dataset...')
        data = AdjacencyDataset(dset, opt.negs, opt.batchsize, opt.ndproc,
            opt.burnin > 0, sample_dampening=opt.dampening)
        model = Embedding(data.N, opt.dim, manifold, sparse=opt.sparse, com_n=opt.com_n)
        objects = dset['objects']
    print('the total dimension', model.lt.weight.data.size(-1), 'com_n', opt.com_n)
    # set burnin parameters
    data.neg_multiplier = opt.neg_multiplier
    train._lr_multiplier = opt.burnin_multiplier
    # Build config string for log
    log.info(f'json_conf: {json.dumps(vars(opt))}')
    if opt.lr_type == 'scale':
        opt.lr = opt.lr * opt.batchsize

    # setup optimizer
    optimizer = RiemannianSGD(model.optim_params(manifold), lr=opt.lr)
    opt.epoch_start = 0
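    # Rebuild the ground-truth adjacency (node -> set of neighbors) from the
    # training batches; eval_reconstruction ranks the embeddings against it below.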
    adj = {}
    for inputs, _ in data:
        for row in inputs:
            x = row[0].item()
            y = row[1].item()
            if x in adj:
                adj[x].add(y)
            else:
                adj[x] = {y}
    if not opt.eval_embedding:
        opt.adj = adj
        model = model.to(device)
        if hasattr(model, 'w_avg'):
            model.w_avg = model.w_avg.to(device)
        if opt.train_threads > 1:
            threads = []
            model = model.share_memory()
            if 'LTiling' in opt.manifold:
                model.int_matrix.share_memory_()
            kwargs = {'progress' : not opt.quiet}
            for i in range(opt.train_threads):
                args = (i, device, model, data, optimizer, opt, log)
                threads.append(mp.Process(target=train.train, args=args, kwargs=kwargs))
                threads[-1].start()
            for t in threads:
                t.join()
        else:
            train.train(device, model, data, optimizer, opt, log, progress=not opt.quiet)
    else:
        # note: the code below accesses model.lt (and model.int_matrix for LTiling),
        # so the checkpoint's 'embeddings' entry is assumed to hold a full model here
        model = th.load(opt.eval_embedding, map_location='cpu')['embeddings']

    if 'LTiling' in opt.manifold:
        meanrank, maprank = eval_reconstruction(
            adj, model.lt.weight.data.clone(), manifold.distance,
            lt_int_matrix=model.int_matrix.data.clone(), workers=opt.ndproc)
        sqnorms = manifold.pnorm(model.lt.weight.data.clone(), model.int_matrix.data.clone())
    else:
        meanrank, maprank = eval_reconstruction(
            adj, model.lt.weight.data.clone(), manifold.distance, workers=opt.ndproc)
        sqnorms = manifold.pnorm(model.lt.weight.data.clone())

    log.info(
        'json_stats final test: {'
        f'"sqnorm_min": {sqnorms.min().item()}, '
        f'"sqnorm_avg": {sqnorms.mean().item()}, '
        f'"sqnorm_max": {sqnorms.max().item()}, '
        f'"mean_rank": {meanrank}, '
        f'"map": {maprank}, '
        '}'
    )
    print(model.lt.weight.data[0])
Example #5
np.random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument('file', help='Path to checkpoint')
parser.add_argument('-workers', default=1, type=int, help='Number of workers')
parser.add_argument('-sample', type=int, help='Sample size')
parser.add_argument('-quiet', action='store_true', default=False)
args = parser.parse_args()

chkpnt = torch.load(args.file)
dset = chkpnt['conf']['dset']
if not os.path.exists(dset):
    raise ValueError("Can't find dset!")

data_format = 'hdf5' if dset.endswith('.h5') else 'csv'
dset = load_adjacency_matrix(dset, data_format)

sample_size = args.sample or len(dset['ids'])
sample = np.random.choice(len(dset['ids']), size=sample_size, replace=False)

adj = {}

for i in sample:
    end = dset['offsets'][i + 1] if i + 1 < len(dset['offsets']) \
        else len(dset['neighbors'])
    adj[i] = set(dset['neighbors'][dset['offsets'][i]:end])

manifold = MANIFOLDS[chkpnt['conf']['manifold']]()

lt = chkpnt['embeddings']
if not isinstance(lt, torch.Tensor):
    # the original snippet is truncated here; completed following Example #6
    lt = torch.from_numpy(lt)
Example #6
parser = argparse.ArgumentParser()
parser.add_argument('file', help='Path to checkpoint')
parser.add_argument('-workers', default=1, type=int, help='Number of workers')
parser.add_argument('-sample', type=int, help='Sample size')
parser.add_argument('-quiet', action='store_true', default=False)
args = parser.parse_args()


chkpnt = torch.load(args.file)
dset = chkpnt['conf']['dset']
if not os.path.exists(dset):
    raise ValueError("Can't find dset!")

data_format = 'hdf5' if dset.endswith('.h5') else 'csv'
dset = load_adjacency_matrix(dset, data_format, objects=chkpnt['objects'])

sample_size = args.sample or len(dset['ids'])
sample = np.random.choice(len(dset['ids']), size=sample_size, replace=False)

adj = {}

for i in sample:
    end = dset['offsets'][i + 1] if i + 1 < len(dset['offsets']) \
        else len(dset['neighbors'])
    adj[dset['ids'][i]] = set(dset['neighbors'][dset['offsets'][i]:end])
manifold = MANIFOLDS[chkpnt['conf']['manifold']]()

lt = chkpnt['embeddings']
if not isinstance(lt, torch.Tensor):
    lt = torch.from_numpy(lt).cuda()
Example #7
def main():
    parser = argparse.ArgumentParser(description='Train Hyperbolic Embeddings')
    parser.add_argument('-checkpoint', default='/tmp/hype_embeddings.pth',
                        help='Where to store the model checkpoint')
    parser.add_argument('-dset', type=str, required=True,
                        help='Dataset identifier')
    parser.add_argument('-dim', type=int, default=20,
                        help='Embedding dimension')
    parser.add_argument('-manifold', type=str, default='lorentz',
                        choices=MANIFOLDS.keys(), help='Embedding manifold')
    parser.add_argument('-lr', type=float, default=1000,
                        help='Learning rate')
    parser.add_argument('-epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('-batchsize', type=int, default=12800,
                        help='Batchsize')
    parser.add_argument('-negs', type=int, default=50,
                        help='Number of negatives')
    parser.add_argument('-burnin', type=int, default=20,
                        help='Epochs of burn in')
    parser.add_argument('-dampening', type=float, default=0.75,
                        help='Sample dampening during burnin')
    parser.add_argument('-ndproc', type=int, default=8,
                        help='Number of data loading processes')
    parser.add_argument('-eval_each', type=int, default=1,
                        help='Run evaluation every n-th epoch')
    parser.add_argument('-fresh', action='store_true', default=False,
                        help='Override checkpoint')
    parser.add_argument('-debug', action='store_true', default=False,
                        help='Print debugging output')
    parser.add_argument('-gpu', default=0, type=int,
                        help='Which GPU to run on (-1 for no gpu)')
    parser.add_argument('-sym', action='store_true', default=False,
                        help='Symmetrize dataset')
    parser.add_argument('-maxnorm', '-no-maxnorm', default='500000',
                        action=Unsettable, type=int)
    parser.add_argument('-sparse', default=False, action='store_true',
                        help='Use sparse gradients for embedding table')
    parser.add_argument('-burnin_multiplier', default=0.01, type=float)
    parser.add_argument('-neg_multiplier', default=1.0, type=float)
    parser.add_argument('-quiet', action='store_true', default=False)
    parser.add_argument('-lr_type', choices=['scale', 'constant'], default='constant')
    parser.add_argument('-train_threads', type=int, default=1,
                        help='Number of threads to use in training')
    opt = parser.parse_args()

    # setup debugging and logging
    log_level = logging.DEBUG if opt.debug else logging.INFO
    log = logging.getLogger('lorentz')
    logging.basicConfig(level=log_level, format='%(message)s', stream=sys.stdout)

    if opt.gpu >= 0 and opt.train_threads > 1:
        opt.gpu = -1
        log.warning('Specified hogwild training with GPU, defaulting to CPU...')

    # set default tensor type
    th.set_default_tensor_type('torch.DoubleTensor')
    # set device
    device = th.device(f'cuda:{opt.gpu}' if opt.gpu >= 0 else 'cpu')

    # select manifold to optimize on
    manifold = MANIFOLDS[opt.manifold](debug=opt.debug, max_norm=opt.maxnorm)
    opt.dim = manifold.dim(opt.dim)

    if 'csv' in opt.dset:
        log.info('Using edge list dataloader')
        idx, objects, weights = load_edge_list(opt.dset, opt.sym)
        model, data, model_name, conf = initialize(
            manifold, opt, idx, objects, weights, sparse=opt.sparse
        )
    else:
        log.info('Using adjacency matrix dataloader')
        dset = load_adjacency_matrix(opt.dset, 'hdf5')
        log.info('Setting up dataset...')
        data = AdjacencyDataset(dset, opt.negs, opt.batchsize, opt.ndproc,
            opt.burnin > 0, sample_dampening=opt.dampening)
        model = Embedding(data.N, opt.dim, manifold, sparse=opt.sparse)
        objects = dset['objects']

    # set burnin parameters
    data.neg_multiplier = opt.neg_multiplier
    train._lr_multiplier = opt.burnin_multiplier

    # Build config string for log
    log.info(f'json_conf: {json.dumps(vars(opt))}')

    if opt.lr_type == 'scale':
        opt.lr = opt.lr * opt.batchsize

    # setup optimizer
    optimizer = RiemannianSGD(model.optim_params(manifold), lr=opt.lr)
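    # RiemannianSGD steps with the manifold-specific operations (rgrad, expm, ...)
    # exposed through model.optim_params(manifold), not a plain Euclidean update.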

    # setup checkpoint
    checkpoint = LocalCheckpoint(
        opt.checkpoint,
        include_in_all={'conf' : vars(opt), 'objects' : objects},
        start_fresh=opt.fresh
    )

    # get state from checkpoint
    state = checkpoint.initialize({'epoch': 0, 'model': model.state_dict()})
    model.load_state_dict(state['model'])
    opt.epoch_start = state['epoch']
    
    # Justin: dump the model parameters to disk for inspection. Note that with more
    # than one parameter tensor, each iteration rewrites pe.coors.txt, so only the
    # last parameter's values survive on disk.
    print("Model State:")
    for parms in model.state_dict():
        print(parms, "\t", model.state_dict()[parms].size())
        np_parms = model.state_dict()[parms].numpy()
        np.savetxt('pe.coors.txt', np_parms)

    adj = {}
    for inputs, _ in data:
        for row in inputs:
            x = row[0].item()
            y = row[1].item()
            if x in adj:
                adj[x].add(y)
            else:
                adj[x] = {y}

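    # Evaluation runs in a separate process: control() pushes checkpoint paths onto
    # controlQ, and async_eval reports ranking stats back through logQ.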
    controlQ, logQ = mp.Queue(), mp.Queue()
    control_thread = mp.Process(target=async_eval, args=(adj, controlQ, logQ, opt))
    control_thread.start()

    # control closure
    def control(model, epoch, elapsed, loss):
        """
        Control thread to evaluate embedding
        """
        lt = model.w_avg if hasattr(model, 'w_avg') else model.lt.weight.data
        manifold.normalize(lt)

        checkpoint.path = f'{opt.checkpoint}.{epoch}'
        checkpoint.save({
            'model': model.state_dict(),
            'embeddings': lt,
            'epoch': epoch + 1,  # Justin: resume from the following epoch
            'manifold': opt.manifold,
        })

        controlQ.put((epoch, elapsed, loss, checkpoint.path))

        while not logQ.empty():
            lmsg, pth = logQ.get()
            shutil.move(pth, opt.checkpoint)
            log.info(f'json_stats: {json.dumps(lmsg)}')

    control.checkpoint = True
    model = model.to(device)
    if hasattr(model, 'w_avg'):
        model.w_avg = model.w_avg.to(device)
    if opt.train_threads > 1:
        threads = []
        model = model.share_memory()
        args = (device, model, data, optimizer, opt, log)
        kwargs = {'ctrl': control, 'progress' : not opt.quiet}
        for i in range(opt.train_threads):
            kwargs['rank'] = i
            threads.append(mp.Process(target=train.train, args=args, kwargs=kwargs))
            threads[-1].start()
        for t in threads:
            t.join()
    else:
        train.train(device, model, data, optimizer, opt, log, ctrl=control,
            progress=not opt.quiet)
    controlQ.put(None)
    control_thread.join()
    while not logQ.empty():
        lmsg, pth = logQ.get()
        shutil.move(pth, opt.checkpoint)
        log.info(f'json_stats: {json.dumps(lmsg)}')
Example #8
parser = argparse.ArgumentParser()
parser.add_argument('file', help='Path to checkpoint')
parser.add_argument('-workers', default=1, type=int, help='Number of workers')
parser.add_argument('-sample', type=int, help='Sample size')
parser.add_argument('-quiet', action='store_true', default=False)
args = parser.parse_args()

chkpnt = torch.load(args.file)
dset = chkpnt['conf']['dset']
if not os.path.exists(dset):
    raise ValueError("Can't find dset!")

data_format = 'hdf5' if dset.endswith('.h5') else 'csv'
dset = load_adjacency_matrix(dset, data_format)

sample_size = args.sample or len(dset['ids'])
sample = np.random.choice(len(dset['ids']), size=sample_size, replace=False)

adj = {}

for i in sample:
    end = dset['offsets'][i + 1] if i + 1 < len(dset['offsets']) \
        else len(dset['neighbors'])
    adj[i] = set(dset['neighbors'][dset['offsets'][i]:end])

manifold = MANIFOLDS[chkpnt['conf']['manifold']]()

lt = chkpnt['embeddings']
if not isinstance(lt, torch.Tensor):
    # the original snippet is truncated here; completed following Example #6
    lt = torch.from_numpy(lt)
Example #9
np.random.seed(42)

parser = argparse.ArgumentParser()
parser.add_argument('file', help='Path to checkpoint')
parser.add_argument('-workers', default=1, type=int, help='Number of workers')
parser.add_argument('-sample', type=int, help='Sample size')
parser.add_argument('-quiet', action='store_true', default=False)
args = parser.parse_args()

chkpnt = torch.load(args.file)
dset = chkpnt['conf']['dset']
if not os.path.exists(dset):
    raise ValueError("Can't find dset!")

data_format = 'hdf5' if dset.endswith('.h5') else 'csv'
dset = load_adjacency_matrix(dset, data_format)

sample_size = args.sample or len(dset['ids'])
sample = np.random.choice(len(dset['ids']), size=sample_size, replace=False)

adj = {}

for i in sample:
    end = dset['offsets'][i + 1] if i + 1 < len(dset['offsets']) \
        else len(dset['neighbors'])
    adj[i] = set(dset['neighbors'][dset['offsets'][i]:end])

manifold = MANIFOLDS[chkpnt['conf']['manifold']]()

lt = chkpnt['embeddings']
if not isinstance(lt, torch.Tensor):
    # the original snippet is truncated here; completed following Example #6
    lt = torch.from_numpy(lt)
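
# A plausible continuation (the original snippet is truncated here), mirroring the
# reconstruction evaluation in Example #10:
#   meanrank, maprank = eval_reconstruction(adj, lt, manifold.distance, workers=args.workers)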
Example #10
def main():
    parser = argparse.ArgumentParser(description='Train Hyperbolic Embeddings')
    parser.add_argument('-checkpoint', default='/tmp/hype_embeddings.pth',
                        help='Where to store the model checkpoint')
    parser.add_argument('-dset', type=str, required=True,
                        help='Dataset identifier')
    parser.add_argument('-dim', type=int, default=20,
                        help='Embedding dimension')
    parser.add_argument('-manifold', type=str, default='lorentz',
                        choices=MANIFOLDS.keys())
    parser.add_argument('-model', type=str, default='distance',
                        choices=MODELS.keys(), help='Energy function model')
    parser.add_argument('-lr', type=float, default=1000,
                        help='Learning rate')
    parser.add_argument('-epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('-batchsize', type=int, default=12800,
                        help='Batchsize')
    parser.add_argument('-negs', type=int, default=50,
                        help='Number of negatives')
    parser.add_argument('-burnin', type=int, default=20,
                        help='Epochs of burn in')
    parser.add_argument('-dampening', type=float, default=0.75,
                        help='Sample dampening during burnin')
    parser.add_argument('-ndproc', type=int, default=8,
                        help='Number of data loading processes')
    parser.add_argument('-eval_each', type=int, default=1,
                        help='Run evaluation every n-th epoch')
    parser.add_argument('-fresh', action='store_true', default=False,
                        help='Override checkpoint')
    parser.add_argument('-debug', action='store_true', default=False,
                        help='Print debugging output')
    parser.add_argument('-gpu', default=0, type=int,
                        help='Which GPU to run on (-1 for no gpu)')
    parser.add_argument('-sym', action='store_true', default=False,
                        help='Symmetrize dataset')
    parser.add_argument('-maxnorm', '-no-maxnorm', default='500000',
                        action=Unsettable, type=int)
    parser.add_argument('-sparse', default=False, action='store_true',
                        help='Use sparse gradients for embedding table')
    parser.add_argument('-burnin_multiplier', default=0.01, type=float)
    parser.add_argument('-neg_multiplier', default=1.0, type=float)
    parser.add_argument('-quiet', action='store_true', default=False)
    parser.add_argument('-lr_type', choices=['scale', 'constant'], default='constant')
    parser.add_argument('-train_threads', type=int, default=1,
                        help='Number of threads to use in training')
    parser.add_argument('-margin', type=float, default=0.1, help='Hinge margin')
    parser.add_argument('-eval', choices=['reconstruction', 'hypernymy'],
                        default='reconstruction', help='Which type of eval to perform')
    opt = parser.parse_args()

    # setup debugging and logging
    log_level = logging.DEBUG if opt.debug else logging.INFO
    log = logging.getLogger('poincare')
    logging.basicConfig(level=log_level, format='%(message)s', stream=sys.stdout)

    # attempt to find GPU
    if opt.gpu >= 0 and opt.train_threads > 1:
        opt.gpu = -1
        log.warning('Specified hogwild training with GPU, defaulting to CPU...')

    # set default tensor type
    if opt.gpu == -1:
        th.set_default_tensor_type('torch.DoubleTensor')
    if opt.gpu == 1:
        th.set_default_tensor_type('torch.cuda.DoubleTensor')

    # set device (note: the cuda index is opt.gpu-1, so this script treats -gpu as 1-based; pass -1 for CPU)
    device = th.device(f'cuda:{opt.gpu-1}' if opt.gpu >= 0 else 'cpu')
    print(f"\n\n opt.gpu = {opt.gpu} \n DEVICE = {device} \n\n")

    # read data (edge set is fed as .csv in train_nouns.sh)
    if 'csv' in opt.dset:
        log.info('Using edge list dataloader')
        idx, objects, weights = load_edge_list(opt.dset, opt.sym)
        data = BatchedDataset(idx, objects, weights, opt.negs, opt.batchsize,
            opt.ndproc, opt.burnin > 0, opt.dampening)
    else:
        log.info('Using adjacency matrix dataloader')
        dset = load_adjacency_matrix(opt.dset, 'hdf5')
        log.info('Setting up dataset...')
        data = AdjacencyDataset(dset, opt.negs, opt.batchsize, opt.ndproc,
            opt.burnin > 0, sample_dampening=opt.dampening)
        objects = dset['objects']

    # create model - read the build_model fn in hype/__init__.py to see how manifold,
    # dim, loss etc. are set up. We store these in model below (model is an object of
    # the DistanceEnergyFunction class, which inherits from EnergyFunction)
    model = build_model(opt, len(objects))
    log.info(f'model is = {model}')
    # set burnin parameters
    data.neg_multiplier = opt.neg_multiplier
    train._lr_multiplier = opt.burnin_multiplier

    # Build config string for log
    log.info(f'json_conf: {json.dumps(vars(opt))}')

    # adjust lr (train_nouns.sh defines opt.lr_type as constant)
    if opt.lr_type == 'scale':
        opt.lr = opt.lr * opt.batchsize

    # Read model params dict. The model is DistanceEnergyFunction
    # (see hype/__init__.py for reason why this is the model)
    # Read EnergyFunction class to see what these params are - they are
    # the expected input to RiemannianSGD class
    log.info('\n\n------------------------------\n'
             'Check expm, logm, ptransp defined for Poincare.\n'
             'Bound method should belong to PoincareManifold, not EuclideanManifold'
             '\n------------------------------\n\n')
    log.info(f'Model expm = {model.optim_params()[0]["expm"]}')
    log.info(f'Model logm = {model.optim_params()[0]["logm"]}')
    log.info(f'Model ptransp = {model.optim_params()[0]["ptransp"]}')
    log.info(f'Model rgrad = {model.optim_params()[0]["rgrad"]}')

    # setup optimizer
    optimizer = RiemannianSGD(model.optim_params(), lr=opt.lr)

    # setup checkpoint
    checkpoint = LocalCheckpoint(
        opt.checkpoint,
        include_in_all={'conf' : vars(opt), 'objects' : objects},
        start_fresh=opt.fresh
    )

    # get state from checkpoint
    state = checkpoint.initialize({'epoch': 0, 'model': model.state_dict()})
    model.load_state_dict(state['model'])
    opt.epoch_start = state['epoch']

    adj = {}
    for inputs, _ in data:
        for row in inputs:
            x = row[0].item()
            y = row[1].item()
            if x in adj:
                adj[x].add(y)
            else:
                adj[x] = {y}

    controlQ, logQ = mp.Queue(), mp.Queue()
    control_thread = mp.Process(target=async_eval, args=(adj, controlQ, logQ, opt))
    control_thread.start()

    # control closure
    def control(model, epoch, elapsed, loss):
        """
        Control thread to evaluate embedding
        """
        lt = model.w_avg if hasattr(model, 'w_avg') else model.lt.weight.data
        model.manifold.normalize(lt)

        checkpoint.path = f'{opt.checkpoint}.{epoch}'
        checkpoint.save({
            'model': model.state_dict(),
            'embeddings': lt,
            'epoch': epoch,
            'model_type': opt.model,
        })

        controlQ.put((epoch, elapsed, loss, checkpoint.path))

        while not logQ.empty():
            lmsg, pth = logQ.get()
            shutil.move(pth, opt.checkpoint)
            if lmsg['best']:
                shutil.copy(opt.checkpoint, opt.checkpoint + '.best')
            log.info(f'json_stats: {json.dumps(lmsg)}')

    control.checkpoint = True
    model = model.to(device)
    if hasattr(model, 'w_avg'):
        model.w_avg = model.w_avg.to(device)
    if opt.train_threads > 1:
        log.info("multi-threaded")
        threads = []
        model = model.share_memory()
        args = (device, model, data, optimizer, opt, log)
        kwargs = {'ctrl': control, 'progress' : not opt.quiet}
        for i in range(opt.train_threads):
            kwargs['rank'] = i
            threads.append(mp.Process(target=train.train, args=args, kwargs=kwargs))
            threads[-1].start()
        for t in threads:
            t.join()
    else:
        log.info("single-threaded")
        train.train(device, model, data, optimizer, opt, log, ctrl=control,
            progress=not opt.quiet)
    controlQ.put(None)
    control_thread.join()
    while not logQ.empty():
        lmsg, pth = logQ.get()
        shutil.move(pth, opt.checkpoint)
        log.info(f'json_stats: {json.dumps(lmsg)}')