Exemple #1
0
def learn(dataset,
          dim=2,
          hyp=1,
          edim=1,
          euc=0,
          sdim=1,
          sph=0,
          scale=1.,
          riemann=False,
          learning_rate=1e-1,
          decay_length=1000,
          decay_step=1.0,
          momentum=0.0,
          tol=1e-8,
          epochs=100,
          burn_in=0,
          use_yellowfin=False,
          use_adagrad=False,
          resample_freq=1000,
          print_freq=1,
          model_save_file=None,
          model_load_file=None,
          batch_size=16,
          num_workers=None,
          lazy_generation=False,
          log_name=None,
          log=False,
          warm_start=None,
          learn_scale=False,
          checkpoint_freq=100,
          sample=1.,
          subsample=None,
          logloss=False,
          distloss=False,
          squareloss=False,
          symloss=False,
          exponential_rescale=None,
          extra_steps=1,
          use_svrg=False,
          T=10,
          use_hmds=False,
          visualize=False):
    """Train a ``ProductEmbedding`` of a graph, logging and checkpointing as it goes.

    The graph is loaded with ``load_graph.load_graph(dataset)``, a training
    set is built with ``build_dataset``, and a model is either restored from
    ``model_load_file`` or freshly constructed.  The model is then trained for
    ``epochs`` epochs; statistics are computed with ``major_stats`` before
    training, at every checkpoint and after training.

    Args:
        dataset: Graph path/identifier passed to ``load_graph.load_graph``;
            also used to derive default log and animation file names.
        dim, hyp, edim, euc, sdim, sph: Forwarded positionally to
            ``ProductEmbedding`` (presumably dimension / number of copies of
            the hyperbolic, Euclidean and spherical factors -- confirm against
            ``ProductEmbedding``).
        scale: Passed to ``build_dataset`` and ``mds_warmstart.get_model``.
            Overwritten by the last column of the first row of the
            ``warm_start`` CSV when warm starting.
        riemann, learn_scale, logloss, distloss, squareloss, symloss:
            Forwarded to ``ProductEmbedding``.
        learning_rate: Base rate.  The default SGD optimizer starts at
            ``learning_rate / 10`` (x10 at the ``burn_in`` milestone), the
            spherical group uses ``0.1 * learning_rate`` and the scale group
            ``1e-4 * learning_rate``.  SVRG receives ``learning_rate`` as is.
        decay_length, decay_step, tol: Accepted but unused by the current code.
        momentum: Momentum of the default SGD optimizer.
        epochs: Number of epochs to run on top of ``m.epoch``.
        burn_in: Milestone epoch of the burn-in LR schedule; checkpoints are
            also taken five times as often while ``i <= burn_in``.
        use_yellowfin, use_adagrad, use_svrg: Replace the default SGD
            optimizer (applied in that order, so later flags win).
        T: Forwarded to ``SVRG``.
        resample_freq: Rebuild the dataset every this many epochs when
            ``sample < 1`` or ``subsample`` is set.
        print_freq: Log the loss every this many epochs.
        checkpoint_freq: Compute stats (and save) every this many epochs.
        model_save_file: Prefix for ``torch.save`` checkpoints; ``None``
            disables saving.
        model_load_file: If given, the model is loaded with ``torch.load``.
        batch_size, num_workers, lazy_generation, sample, subsample:
            Forwarded to ``build_dataset``.
        log_name: File to mirror log output into.  If ``None`` and ``log`` is
            true, a name is derived from ``dataset`` and the hyper-parameters.
        log: See ``log_name``.
        warm_start: CSV path holding an initial embedding; the last column
            holds the scale.
        exponential_rescale: When not ``None`` a piecewise-constant edge
            weight function is used; also forwarded to ``ProductEmbedding``.
        extra_steps: Number of passes over the data loader per epoch (non-SVRG
            path only).
        use_hmds: Initialise from ``mds_warmstart.get_model`` when not warm
            starting from a CSV.
        visualize: Set up and drive an animation through ``vis.setup_plot``.

    Returns:
        None.  Results are written to the log, the ``*.stat`` files and the
        saved model files.
    """
    # Log configuration
    formatter = logging.Formatter('%(asctime)s %(message)s')
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(message)s',
        datefmt='%FT%T',
    )
    if log_name is None and log:
        log_name = f"{os.path.splitext(dataset)[0]}.H{dim}-{hyp}.E{edim}-{euc}.S{sdim}-{sph}.lr{learning_rate}.log"
    # Per-epoch losses; only populated/written when a log file is in use.
    loss_list = []
    if log_name is not None:
        logging.info(f"Logging to {log_name}")
        # Use a distinct local name so the `log` flag parameter is not shadowed.
        root_logger = logging.getLogger()
        fh = logging.FileHandler(log_name)
        fh.setFormatter(formatter)
        root_logger.addHandler(fh)

    logging.info(f"Commandline {sys.argv}")
    if model_save_file is None:
        logging.warning("No Model Save selected!")
    G = load_graph.load_graph(dataset)
    GM = nx.to_scipy_sparse_matrix(G, nodelist=list(range(G.order())))

    # grab scale if warm starting:
    if warm_start:
        # `.values` replaces `DataFrame.as_matrix()`, which was removed in
        # pandas 1.0; both return the underlying ndarray.
        scale = pandas.read_csv(warm_start, index_col=0).values[0, -1]

    n = G.order()
    logging.info(f"Loaded Graph {dataset} with {n} nodes scale={scale}")

    if exponential_rescale is not None:
        # Piecewise-constant stand-in for torch.exp(exponential_rescale * -d)
        def weight_fn(d):
            if d <= 2.0:
                return 5.0
            elif d > 4.0:
                return 0.01
            else:
                return 1.0
    else:

        def weight_fn(d):
            return 1.0

    Z, z = build_dataset(G, lazy_generation, sample, subsample, scale,
                         batch_size, weight_fn, num_workers)

    if model_load_file is not None:
        logging.info(f"Loading {model_load_file}...")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # model files from trusted sources.
        m = torch.load(model_load_file).to(device)
        logging.info(
            f"Loaded scale {unwrap(m.scale())} {torch.sum(m.embedding().data)} {m.epoch}"
        )
    else:
        logging.info(f"Creating a fresh model warm_start?={warm_start}")

        m_init = None
        if warm_start:
            # load from DataFrame; assume that the julia combinatorial embedding has been saved
            ws_data = pandas.read_csv(warm_start, index_col=0).values
            scale = ws_data[0, ws_data.shape[1] - 1]
            m_init = torch.DoubleTensor(ws_data[:,
                                                range(ws_data.shape[1] - 1)])
        elif use_hmds:
            m_init = torch.DoubleTensor(
                mds_warmstart.get_model(dataset, dim, scale)[1])

        logging.info(
            f"\t Warmstarting? {warm_start} {m_init.size() if warm_start else None} {G.order()}"
        )
        # initial_scale = z.dataset.max_dist / 3.0
        initial_scale = 0.0
        m = ProductEmbedding(G.order(),
                             dim,
                             hyp,
                             edim,
                             euc,
                             sdim,
                             sph,
                             initialize=m_init,
                             learn_scale=learn_scale,
                             initial_scale=initial_scale,
                             logrel_loss=logloss,
                             dist_loss=distloss,
                             square_loss=squareloss,
                             sym_loss=symloss,
                             exponential_rescale=exponential_rescale,
                             riemann=riemann).to(device)
        m.normalize()
        m.epoch = 0
    logging.info(
        f"Constructed model with dim={dim} and epochs={m.epoch} isnan={np.any(np.isnan(m.embedding().cpu().data.numpy()))}"
    )

    if visualize:
        # NOTE(review): `seed` is neither a parameter nor a local; it is read
        # from the enclosing/module scope -- confirm it is defined there.
        name = 'animations/' + f"{os.path.split(os.path.splitext(dataset)[0])[1]}.H{dim}-{hyp}.E{edim}-{euc}.S{sdim}-{sph}.lr{learning_rate}.ep{epochs}.seed{seed}"
        fig, ax, writer = vis.setup_plot(m=m, name=name, draw_circle=True)
    else:
        fig = None
        ax = None
        writer = None

    #
    # Build the Optimizer
    #
    # TODO: Redo this in a sensible way!!

    # per-parameter learning rates
    # Parameters flagged `use_exp` are updated by hand via `p.exp(lr)` in the
    # training loop rather than by the optimizer.
    exp_params = [p for p in m.embed_params if p.use_exp]
    learn_params = [p for p in m.embed_params if not p.use_exp]
    hyp_params = [p for p in m.hyp_params if not p.use_exp]
    euc_params = [p for p in m.euc_params if not p.use_exp]
    sph_params = [p for p in m.sph_params if not p.use_exp]
    scale_params = m.scale_params
    model_params = [{
        'params': hyp_params
    }, {
        'params': euc_params
    }, {
        'params': sph_params,
        'lr': 0.1 * learning_rate
    }, {
        'params': m.scale_params,
        'lr': 1e-4 * learning_rate
    }]

    if len(model_params) > 0:
        opt = torch.optim.SGD(model_params,
                              lr=learning_rate / 10,
                              momentum=momentum)
    # `scale_opt` / `scale_decay` are built but not stepped by the current
    # training loop (the calls are commented out below).
    if len(scale_params) > 0:
        scale_opt = torch.optim.SGD(scale_params, lr=1e-3 * learning_rate)
        scale_decay = torch.optim.lr_scheduler.StepLR(scale_opt,
                                                      step_size=1,
                                                      gamma=.99)
    else:
        scale_opt = None
        scale_decay = None
    # NOTE(review): the burn-in schedule is bound to the SGD optimizer created
    # above; if `opt` is replaced below (YellowFin/Adagrad/SVRG) the schedule
    # keeps stepping the discarded SGD instance.
    lr_burn_in = torch.optim.lr_scheduler.MultiStepLR(opt,
                                                      milestones=[burn_in],
                                                      gamma=10)
    # lr_decay = torch.optim.lr_scheduler.StepLR(opt, decay_length, decay_step) #TODO reconcile multiple LR schedulers
    if use_yellowfin:
        from yellowfin import YFOptimizer
        opt = YFOptimizer(model_params)

    if use_adagrad:
        opt = torch.optim.Adagrad(model_params)

    if use_svrg:
        from svrg import SVRG
        base_opt = torch.optim.Adagrad if use_adagrad else torch.optim.SGD
        opt = SVRG(m.parameters(),
                   lr=learning_rate,
                   T=T,
                   data_loader=z,
                   opt=base_opt)
        # TODO add ability for SVRG to take parameter groups

    logging.info(opt)

    # Log stats from import: when warmstarting, check that it matches Julia's stats
    logging.info("*** Initial Checkpoint. Computing Stats")
    major_stats(GM, n, m, lazy_generation, Z, z, fig, ax, writer, visualize,
                subsample)
    logging.info("*** End Initial Checkpoint\n")

    # track best stats
    best_loss = 1.0e10
    best_dist = 1.0e10
    best_wcdist = 1.0e10
    best_map = 0.0
    # Keeps the final-loss reporting below well defined when epochs == 0.
    l = float("nan")
    for i in range(m.epoch + 1, m.epoch + epochs + 1):
        lr_burn_in.step()
        # lr_decay.step()
        # scale_decay.step()

        l, n_edges = 0.0, 0.0  # track average loss per edge
        m.train(True)
        if use_svrg:
            for data in z:

                def closure(data=data, target=None):
                    _data = data if target is None else (data, target)
                    c = m.loss(_data.to(device))
                    c.backward()
                    # `.item()` matches the non-SVRG path; `c.data[0]` fails
                    # on 0-dim loss tensors in current PyTorch.
                    return c.item()

                l += opt.step(closure)

                # Projection
                m.normalize()

        else:
            # scale_opt.zero_grad()
            for the_step in range(extra_steps):
                for u in z:
                    # Zero out the gradients (optimizer-managed parameters,
                    # then the hand-updated exponential-map parameters).
                    opt.zero_grad()
                    for p in exp_params:
                        if p.grad is not None:
                            p.grad.detach_()
                            p.grad.zero_()
                    # Compute loss
                    _loss = m.loss(cu_var(u))
                    _loss.backward()
                    l += _loss.item() * u[0].size(0)
                    n_edges += u[0].size(0)
                    # modify gradients if necessary
                    RParameter.correct_metric(m.parameters())
                    # step
                    opt.step()
                    for p in exp_params:
                        lr = opt.param_groups[0]['lr']
                        p.exp(lr)
                    # Projection
                    m.normalize()
            # scale_opt.step()

        # The SVRG path (and an empty loader) never counts edges; dividing
        # unconditionally raised ZeroDivisionError there.  In that case `l`
        # stays the plain sum of the per-step losses.
        if n_edges > 0:
            l /= n_edges

        # m.epoch refers to num of training epochs finished
        m.epoch += 1

        # Logging code
        if i % print_freq == 0:
            logging.info(f"{i} loss={l}")
            if log_name is not None:
                loss_list.append(l)

        # Checkpoint five times as often during burn-in.
        burn_in_checkpoint = i <= burn_in and i % (checkpoint_freq / 5) == 0
        if burn_in_checkpoint or i % checkpoint_freq == 0:
            logging.info("\n*** Major Checkpoint. Computing Stats and Saving")
            avg_dist, wc_dist, me, mc, mapscore = major_stats(
                GM, n, m, True, Z, z, fig, ax, writer, visualize, subsample)
            best_loss = min(best_loss, l)
            best_dist = min(best_dist, avg_dist)
            best_wcdist = min(best_wcdist, wc_dist)
            best_map = max(best_map, mapscore)
            if model_save_file is not None:
                fname = f"{model_save_file}.{m.epoch}"
                logging.info(
                    f"Saving model into {fname} {torch.sum(m.embedding().data)} "
                )
                torch.save(m, fname)
            logging.info("*** End Major Checkpoint\n")
        if i % resample_freq == 0:
            if sample < 1. or subsample is not None:
                Z, z = build_dataset(G, lazy_generation, sample, subsample,
                                     scale, batch_size, weight_fn, num_workers)

    logging.info(f"final loss={l}")
    logging.info(
        f"best loss={best_loss}, distortion={best_dist}, map={best_map}, wc_dist={best_wcdist}"
    )

    final_dist, final_wc, final_me, final_mc, final_map = major_stats(
        GM, n, m, lazy_generation, Z, z, fig, ax, writer, False, subsample)

    if log_name is not None:
        with open(log_name + "_loss.stat", "w") as f:
            for loss in loss_list:
                f.write("%f\n" % loss)

        with open(log_name + '_final.stat', "w") as f:
            f.write("Best-loss MAP dist wc Final-loss MAP dist wc me mc\n")
            f.write(
                f"{best_loss:10.6f} {best_map:8.4f} {best_dist:8.4f} {best_wcdist:8.4f} {l:10.6f} {final_map:8.4f} {final_dist:8.4f} {final_wc:8.4f} {final_me:8.4f} {final_mc:8.4f}"
            )

    if visualize:
        writer.finish()

    if model_save_file is not None:
        fname = f"{model_save_file}.final"
        logging.info(
            f"Saving model into {fname}-final {torch.sum(m.embedding().data)} {unwrap(m.scale())}"
        )
        torch.save(m, fname)
Exemple #2
0
def test_measurement(zero_debias=True):
    """Check YFOptimizer's running measurements against hand-computed targets.

    At iteration ``i`` every gradient entry is overwritten with ``i + 1``.
    The test recomputes 0.999-decay moving averages for the curvature range,
    gradient variance and distance-to-optimum and asserts that the optimizer's
    ``_h_max``, ``_h_min``, ``_grad_var`` and ``_dist_to_opt`` agree to within
    a relative tolerance of 1e-3.  Iteration 0 is not checked.

    Reads ``n_dim`` and ``n_iter`` from module scope.

    Args:
        zero_debias: Forwarded to ``YFOptimizer``.  When true, the targets are
            divided by ``1 - 0.999**(i + 1)`` before being compared.
    """
    dtype = torch.FloatTensor
    w = Variable(torch.ones(n_dim, 1).type(dtype), requires_grad=True)
    b = Variable(torch.ones(1).type(dtype), requires_grad=True)
    x = Variable(torch.ones(1, n_dim).type(dtype), requires_grad=False)
    opt = YFOptimizer([w, b], lr=1.0, mu=0.0, zero_debias=zero_debias)

    target_h_max = 0.0
    target_h_min = 0.0
    g_norm_squared_avg = 0.0
    g_norm_avg = 0.0
    g_avg = 0.0
    target_dist = 0.0
    for i in range(n_iter):
        opt.zero_grad()

        loss = (x.mm(w) + b).sum()
        loss.backward()
        # Replace the real gradients with a known constant so the expected
        # statistics can be computed in closed form below.
        w.grad.data = (i + 1) * torch.ones([
            n_dim,
        ]).type(dtype)
        b.grad.data = (i + 1) * torch.ones([
            1,
        ]).type(dtype)

        opt.step()

        res = [opt._h_max, opt._h_min, opt._grad_var, opt._dist_to_opt]

        # Zero-debias correction factor for a 0.999-decay moving average.
        debias = 1 - 0.999**(i + 1)

        g_norm_squared_avg = 0.999 * g_norm_squared_avg  \
            + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ]))**2)
        g_norm_avg = 0.999 * g_norm_avg  \
            + 0.001 * np.linalg.norm((i + 1) * np.ones([n_dim + 1, ]))
        g_avg = 0.999 * g_avg + 0.001 * (i + 1)

        target_h_max = 0.999 * target_h_max + 0.001 * (i + 1)**2 * (n_dim + 1)
        # h_min target: gradient magnitude from 19 iterations earlier
        # (floored at 1).
        target_h_min = 0.999 * target_h_min + 0.001 * \
            max(1, i + 2 - 20)**2 * (n_dim + 1)
        if zero_debias:
            target_var = g_norm_squared_avg / debias \
                - g_avg**2 * (n_dim + 1) / debias**2
        else:
            target_var = g_norm_squared_avg - g_avg**2 * (n_dim + 1)
        target_dist = 0.999 * target_dist + 0.001 * g_norm_avg / g_norm_squared_avg

        if i == 0:
            continue
        if zero_debias:
            assert np.abs(target_h_max / debias -
                          res[0]) < np.abs(res[0]) * 1e-3
            assert np.abs(target_h_min / debias -
                          res[1]) < np.abs(res[1]) * 1e-3
            assert np.abs(target_var - res[2]) < np.abs(target_var) * 1e-3
            assert np.abs(target_dist / debias -
                          res[3]) < np.abs(res[3]) * 1e-3
        else:
            assert np.abs(target_h_max - res[0]) < np.abs(target_h_max) * 1e-3
            assert np.abs(target_h_min - res[1]) < np.abs(target_h_min) * 1e-3
            assert np.abs(target_var - res[2]) < np.abs(res[2]) * 1e-3
            assert np.abs(target_dist - res[3]) < np.abs(res[3]) * 1e-3
    # print() call form: valid on both Python 2 and Python 3 for one argument.
    print("sync measurement test passed!")
Exemple #3
0
def test_lr_mu(zero_debias=False):
    """Check YFOptimizer's tuned learning rate and momentum against targets.

    At iteration ``i`` every gradient entry is overwritten with ``i + 1``.
    The test recomputes 0.999-decay moving averages for the curvature range,
    gradient variance and distance-to-optimum, asserts the optimizer's
    measurements match them, then feeds the targets to ``tune_everything`` and
    asserts that the smoothed results match ``opt._lr`` (relative tolerance
    1e-3) and ``opt._mu`` (relative tolerance 5e-3).  Iteration 0 is not
    checked.

    Reads ``n_dim`` and ``n_iter`` from module scope.

    Args:
        zero_debias: Forwarded to ``YFOptimizer``.  When true, the targets are
            divided by ``1 - 0.999**(i + 1)`` before being compared or tuned.
    """
    dtype = torch.FloatTensor
    w = Variable(torch.ones(n_dim, 1).type(dtype), requires_grad=True)
    b = Variable(torch.ones(1).type(dtype), requires_grad=True)
    x = Variable(torch.ones(1, n_dim).type(dtype), requires_grad=False)
    opt = YFOptimizer([w, b], lr=1.0, mu=0.0, zero_debias=zero_debias)

    target_h_max = 0.0
    target_h_min = 0.0
    g_norm_squared_avg = 0.0
    g_norm_avg = 0.0
    g_avg = 0.0
    target_dist = 0.0
    target_lr = 1.0
    target_mu = 0.0
    for i in range(n_iter):
        opt.zero_grad()

        loss = (x.mm(w) + b).sum()
        loss.backward()
        # Replace the real gradients with a known constant so the expected
        # statistics can be computed in closed form below.
        w.grad.data = (i + 1) * torch.ones([
            n_dim,
        ]).type(dtype)
        b.grad.data = (i + 1) * torch.ones([
            1,
        ]).type(dtype)

        opt.step()
        res = [
            opt._h_max, opt._h_min, opt._grad_var, opt._dist_to_opt, opt._lr,
            opt._mu
        ]

        # Zero-debias correction factor for a 0.999-decay moving average.
        debias = 1 - 0.999**(i + 1)

        g_norm_squared_avg = 0.999 * g_norm_squared_avg  \
            + 0.001 * np.sum(((i + 1) * np.ones([n_dim + 1, ]))**2)
        g_norm_avg = 0.999 * g_norm_avg  \
            + 0.001 * np.linalg.norm((i + 1) * np.ones([n_dim + 1, ]))
        g_avg = 0.999 * g_avg + 0.001 * (i + 1)

        target_h_max = 0.999 * target_h_max + 0.001 * (i + 1)**2 * (n_dim + 1)
        # h_min target: gradient magnitude from 19 iterations earlier
        # (floored at 1).
        target_h_min = 0.999 * target_h_min + 0.001 * \
            max(1, i + 2 - 20)**2 * (n_dim + 1)
        if zero_debias:
            target_var = g_norm_squared_avg / debias \
                - g_avg**2 * (n_dim + 1) / debias**2
        else:
            target_var = g_norm_squared_avg - g_avg**2 * (n_dim + 1)
        target_dist = 0.999 * target_dist + 0.001 * g_norm_avg / g_norm_squared_avg

        if i == 0:
            continue
        if zero_debias:
            assert np.abs(target_h_max / debias -
                          res[0]) < np.abs(res[0]) * 1e-3
            assert np.abs(target_h_min / debias -
                          res[1]) < np.abs(res[1]) * 1e-3
            assert np.abs(target_var - res[2]) < np.abs(target_var) * 1e-3
            assert np.abs(target_dist / debias -
                          res[3]) < np.abs(res[3]) * 1e-3
        else:
            assert np.abs(target_h_max - res[0]) < np.abs(target_h_max) * 1e-3
            assert np.abs(target_h_min - res[1]) < np.abs(target_h_min) * 1e-3
            assert np.abs(target_var - res[2]) < np.abs(res[2]) * 1e-3
            assert np.abs(target_dist - res[3]) < np.abs(res[3]) * 1e-3

        # i > 0 is guaranteed here by the `continue` above.
        if zero_debias:
            lr, mu = tune_everything((target_dist / debias)**2, target_var, 1,
                                     target_h_min / debias,
                                     target_h_max / debias)
        else:
            lr, mu = tune_everything(target_dist**2, target_var, 1,
                                     target_h_min, target_h_max)
        # Keep only the real part of the tuner's results.
        lr = np.real(lr)
        mu = np.real(mu)
        target_lr = 0.999 * target_lr + 0.001 * lr
        target_mu = 0.999 * target_mu + 0.001 * mu
        assert target_lr == 0.0 or np.abs(target_lr -
                                          res[4]) < np.abs(res[4]) * 1e-3
        assert target_mu == 0.0 or np.abs(target_mu -
                                          res[5]) < np.abs(res[5]) * 5e-3
    # print() call form: valid on both Python 2 and Python 3 for one argument.
    print("lr and mu computing test passed!")
        for param in optimizer._optimizer.param_groups[0]['params']:
            param.grad = target_norm * param.grad / grad_norm

        # You can enable this to see some slow reaction from our estimators
        # when the zero gradients start
        if True and (t > 6490 or t < 20):
            print(t, loss.data.item())
            print('Curvatures', optimizer._h_max, optimizer._h_min)
            print('mu_t, lr_t', optimizer._mu_t, optimizer._lr_t)
            print
    else:
        print(t, loss.data.item())

    # Calling the step function on an Optimizer makes an update to its parameters
    optimizer.step()

# grad_norm = torch_list_grad_norm(optimizer._optimizer.param_groups[0]['params'])

# 4/(np.sqrt(optimizer._h_max)+np.sqrt(optimizer._h_min))**2

# 4/(2*np.sqrt(grad_norm**2))**2

# optimizer._lr_t

# # In[14]:

# optimizer._lr

# # In[ ]:
def learn(dataset,
          rank=2,
          scale=1.,
          learning_rate=1e-1,
          tol=1e-8,
          epochs=100,
          use_yellowfin=False,
          use_adagrad=False,
          print_freq=1,
          model_save_file=None,
          model_load_file=None,
          batch_size=16,
          num_workers=None,
          lazy_generation=False,
          log_name=None,
          warm_start=None,
          learn_scale=False,
          checkpoint_freq=1000,
          sample=1.,
          subsample=None,
          exponential_rescale=None,
          extra_steps=1,
          use_svrg=False,
          T=10,
          use_hmds=False):
    """Train a ``Hyperbolic_Emb`` embedding of a graph, logging and checkpointing.

    The graph is loaded with ``load_graph.load_graph(dataset)`` and wrapped in
    a ``DataLoader`` (row samplers when ``lazy_generation`` or ``subsample``
    is set, otherwise all ``i < j`` pairs of the full distance matrix).  A
    model is restored from ``model_load_file`` or freshly constructed, then
    trained for up to ``epochs`` epochs, stopping early once the epoch loss
    drops below ``tol``.

    Args:
        dataset: Graph path/identifier passed to ``load_graph.load_graph``.
        rank: Forwarded to ``Hyperbolic_Emb`` and ``mds_warmstart.get_model``.
        scale: Distance scale factor passed to the samplers /
            ``gh.build_distance``.  Overwritten by the last column of the
            first row of the ``warm_start`` CSV when warm starting.
        learning_rate: Learning rate of the default SGD optimizer and of SVRG.
        tol: Training stops as soon as an epoch's summed loss is below this.
        epochs: Maximum number of epochs to run on top of ``m.epoch``.
        use_yellowfin, use_adagrad, use_svrg: Replace the default SGD
            optimizer (applied in that order, so later flags win).
        print_freq: Log the loss every this many epochs.
        model_save_file: Prefix for ``torch.save`` checkpoints; ``None``
            disables saving.
        model_load_file: If given, the model is loaded with ``torch.load``.
        batch_size: Batch size of the ``DataLoader``.
        num_workers: Worker count for ``gh.build_distance`` (16 when ``None``).
        lazy_generation: Sample graph rows on the fly instead of building the
            full distance matrix.
        log_name: File to mirror log output into.
        warm_start: CSV path holding an initial embedding; the last column
            holds the scale.
        learn_scale, exponential_rescale: Forwarded to ``Hyperbolic_Emb``.
        checkpoint_freq: Compute stats (and save) every this many epochs.
        sample: When ``< 1`` the distance matrix is resampled through
            ``gh.dist_sample_rebuild_pos_neg`` before the pairs are collected.
        subsample: When not ``None`` a ``GraphRowSubSampler`` is used.
        extra_steps: Number of accumulate-then-step rounds per epoch (non-SVRG
            path only).
        T: Forwarded to ``SVRG``.
        use_hmds: Initialise from ``mds_warmstart.get_model`` when not warm
            starting from a CSV.

    Returns:
        None.  Results are written to the log and the saved model files.
    """
    # Log configuration
    formatter = logging.Formatter('%(asctime)s %(message)s')
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s %(message)s',
        datefmt='%FT%T',
    )
    if log_name is not None:
        logging.info(f"Logging to {log_name}")
        log = logging.getLogger()
        fh = logging.FileHandler(log_name)
        fh.setFormatter(formatter)
        log.addHandler(fh)

    logging.info(f"Commandline {sys.argv}")
    if model_save_file is None:
        # logging.warn is a deprecated alias of logging.warning.
        logging.warning("No Model Save selected!")
    G = load_graph.load_graph(dataset)
    GM = nx.to_scipy_sparse_matrix(G)

    # grab scale if warm starting:
    if warm_start:
        # `.values` replaces `DataFrame.as_matrix()`, which was removed in
        # pandas 1.0; both return the underlying ndarray.
        scale = pandas.read_csv(warm_start, index_col=0).values[0, -1]

    n = G.order()
    logging.info(f"Loaded Graph {dataset} with {n} nodes scale={scale}")

    Z = None

    def collate(ls):
        # Each sample is an (x, y) pair of tensors; concatenate them per field.
        x, y = zip(*ls)
        return torch.cat(x), torch.cat(y)

    if lazy_generation:
        if subsample is not None:
            z = DataLoader(GraphRowSubSampler(G, scale, subsample),
                           batch_size,
                           shuffle=True,
                           collate_fn=collate)
        else:
            z = DataLoader(GraphRowSampler(G, scale),
                           batch_size,
                           shuffle=True,
                           collate_fn=collate)
        logging.info("Built Data Sampler")
    else:
        Z = gh.build_distance(G,
                              scale,
                              num_workers=int(num_workers) if num_workers
                              is not None else 16)  # load the whole matrix
        logging.info(f"Built distance matrix with {scale} factor")

        if subsample is not None:
            z = DataLoader(GraphRowSubSampler(G, scale, subsample, Z=Z),
                           batch_size,
                           shuffle=True,
                           collate_fn=collate)
        else:
            # Every unordered node pair (i < j) becomes one training example.
            idx = torch.LongTensor([(i, j) for i in range(n)
                                    for j in range(i + 1, n)])
            Z_sampled = gh.dist_sample_rebuild_pos_neg(
                Z, sample) if sample < 1 else Z
            vals = torch.DoubleTensor(
                [Z_sampled[i, j] for i in range(n) for j in range(i + 1, n)])
            z = DataLoader(TensorDataset(idx, vals),
                           batch_size=batch_size,
                           shuffle=True,
                           pin_memory=torch.cuda.is_available())
        logging.info("Built data loader")

    if model_load_file is not None:
        logging.info(f"Loading {model_load_file}...")
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # model files from trusted sources.
        m = cudaify(torch.load(model_load_file))
        logging.info(
            f"Loaded scale {m.scale.data[0]} {torch.sum(m.w.data)} {m.epoch}")
    else:
        logging.info(f"Creating a fresh model warm_start?={warm_start}")

        m_init = None
        if warm_start:
            # load from DataFrame; assume that the julia combinatorial embedding has been saved
            ws_data = pandas.read_csv(warm_start, index_col=0).values
            scale = ws_data[0, ws_data.shape[1] - 1]
            m_init = torch.DoubleTensor(ws_data[:,
                                                range(ws_data.shape[1] - 1)])
        elif use_hmds:
            m_init = torch.DoubleTensor(
                mds_warmstart.get_model(dataset, rank, scale)[1])

        logging.info(
            f"\t Warmstarting? {warm_start} {m_init.size() if warm_start else None} {G.order()}"
        )
        m = cudaify(
            Hyperbolic_Emb(G.order(),
                           rank,
                           initialize=m_init,
                           learn_scale=learn_scale,
                           exponential_rescale=exponential_rescale))
        m.normalize()
        m.epoch = 0
    logging.info(
        f"Constructed model with rank={rank} and epochs={m.epoch} isnan={np.any(np.isnan(m.w.cpu().data.numpy()))}"
    )

    #
    # Build the Optimizer
    #
    # TODO: Redo this in a sensible way!!
    #
    opt = torch.optim.SGD(m.parameters(), lr=learning_rate)
    if use_yellowfin:
        from yellowfin import YFOptimizer
        opt = YFOptimizer(m.parameters())

    if use_adagrad:
        opt = torch.optim.Adagrad(m.parameters())

    if use_svrg:
        from svrg import SVRG
        base_opt = torch.optim.Adagrad if use_adagrad else torch.optim.SGD
        opt = SVRG(m.parameters(),
                   lr=learning_rate,
                   T=T,
                   data_loader=z,
                   opt=base_opt)

    logging.info(opt)

    # Log stats from import: when warmstarting, check that it matches Julia's stats
    logging.info("*** Initial Checkpoint. Computing Stats")
    major_stats(GM, 1 + m.scale.data[0], n, m, lazy_generation, Z, z)
    logging.info("*** End Initial Checkpoint\n")

    # Keeps the final-loss log below well defined when epochs == 0.
    l = float("nan")
    for i in range(m.epoch, m.epoch + epochs):
        l = 0.0
        m.train(True)
        if use_svrg:
            for data in z:

                def closure(data=data, target=None):
                    _data = data if target is None else (data, target)
                    c = m.loss(cu_var(_data))
                    c.backward()
                    return c.data[0]

                l += opt.step(closure)

                # Projection
                m.normalize()

        else:
            # NOTE(review): gradients are zeroed once per epoch, so with
            # extra_steps > 1 they keep accumulating across the inner steps
            # -- confirm this is intended.
            opt.zero_grad()  # This is handled by the SVRG.
            for the_step in range(extra_steps):
                # Accumulate the gradient over the whole loader, then take a
                # single optimizer step.
                for u in z:
                    _loss = m.loss(cu_var(u, requires_grad=False))
                    _loss.backward()
                    l += _loss.data[0]
                Hyperbolic_Parameter.correct_metric(
                    m.parameters())  # NB: THIS IS THE NEW CALL
                opt.step()
                # Projection
                m.normalize()

        # Logging code
        if l < tol:
            # The f-prefix was missing, so the placeholders were logged
            # literally instead of the loss and iteration.
            logging.info(f"Found a {l} solution. Done at iteration {i}!")
            break
        if i % print_freq == 0:
            logging.info(f"{i} loss={l}")
        if i % checkpoint_freq == 0:
            logging.info("\n*** Major Checkpoint. Computing Stats and Saving")
            major_stats(GM, 1 + m.scale.data[0], n, m, True, Z, z)
            if model_save_file is not None:
                fname = f"{model_save_file}.{m.epoch}"
                logging.info(
                    f"Saving model into {fname} {torch.sum(m.w.data)} ")
                torch.save(m, fname)
            logging.info("*** End Major Checkpoint\n")
        m.epoch += 1

    logging.info(f"final loss={l}")

    if model_save_file is not None:
        fname = f"{model_save_file}.final"
        logging.info(
            f"Saving model into {fname}-final {torch.sum(m.w.data)} {m.scale.data[0]}"
        )
        torch.save(m, fname)

    major_stats(GM, 1 + m.scale.data[0], n, m, lazy_generation, Z, z)
        x = numpy.asarray(x, dtype=numpy.float32)
        x = torch.from_numpy(x)
        x = x.view(x.size()[0], x.size()[1], input_size)
        y = torch.cat((x[:, 1:, :], torch.zeros([x.size()[0], 1, input_size])),
                      1)
        images = Variable(x).cuda()
        labels = Variable(y).long().cuda()
        opt.zero_grad()
        outputs = rnn(images)
        shp = outputs.size()
        outputs_reshp = outputs.view([shp[0] * shp[1], num_classes])
        labels_reshp = labels.view(shp[0] * shp[1])
        loss = criterion(outputs_reshp, labels_reshp)
        loss.backward()
        opt.step()

        # For plotting
        #unclip_g_norm_list.append(opt._unclip_grad_norm)

        t += time.time()

        loss_list.append(loss.data[0])

        local_curv_list.append(opt._global_state['grad_norm_squared'])
        max_curv_list.append(opt._h_max)
        min_curv_list.append(opt._h_min)
        lr_list.append(opt._lr)
        mu_list.append(opt._mu)
        dr_list.append((opt._h_max + 1e-6) / (opt._h_min + 1e-6))
        dist_list.append(opt._dist_to_opt)