import numpy as np
import torch

# Note: this snippet assumes module-level names defined elsewhere in the file:
# `args` (parsed CLI arguments), `rounding` (randomized rounding helper),
# `vals` (results table), `test_instances`, `num_cluster_iter`, and `loss_fn`.


def get_kcenter_test_loss(model,
                          adj,
                          bin_adj,
                          train_objectives,
                          test_objectives,
                          instances,
                          features,
                          num_reps=10,
                          hardmax=False,
                          update=False,
                          algoname=None):
    loss = 0
    for i in instances:
        best_loss = float('inf')  # best training loss seen across repetitions
        x_best = None
        for _ in range(num_reps):
            x = model(features[i], adj[i])
            # Rescale x so that at most K (fractional) centers are selected.
            if x.sum() > args.K:
                x = args.K * x / x.sum()
            train_loss = train_objectives[i](x)
            if train_loss.item() < best_loss:
                best_loss = train_loss.item()
                x_best = x
        # Round the best fractional solution 50 times; keep the rounding with
        # the best *training* value and report its *test* value.
        testvals = []
        trainvals = []
        for _ in range(50):
            y = rounding(x_best)
            testvals.append(test_objectives[i](y).item())
            trainvals.append(train_objectives[i](y).item())
        loss += testvals[np.argmin(trainvals)]
        if update:
            vals[algoname][test_instances.index(i)] = testvals[
                np.argmin(trainvals)]
    return loss / len(instances)
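

# NOTE: `rounding` is used above but is not defined in this snippet. A minimal
# sketch, assuming it performs independent randomized rounding of the
# fractional selection vector x (each entry treated as a Bernoulli
# probability); the actual repository may use a different rounding scheme.
def rounding(x):
    # Clamp to [0, 1] for numerical safety, then sample each coordinate.
    return torch.bernoulli(torch.clamp(x, 0, 1))
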

# Variant of the same evaluation loop for the ClusterNet model, whose forward
# pass returns (mu, r, embeds, dist) and is scored with the clustering
# loss_fn.
def get_kcenter_test_loss(model,
                          adj,
                          bin_adj,
                          train_objectives,
                          test_objectives,
                          instances,
                          features,
                          num_reps=10,
                          hardmax=False,
                          update=False,
                          algoname=None):
    loss = 0
    for i in instances:
        best_loss = float('inf')  # best training loss seen across repetitions
        x_best = None
        for _ in range(num_reps):
            mu, r, embeds, dist = model(features[i], adj[i],
                                        num_cluster_iter)
            # Convert node-to-center scores into a fractional selection vector
            # x in [0, 1)^n, then rescale so at most K centers are selected.
            x = torch.softmax(dist * args.kcentertemp, 0).sum(dim=1)
            x = 2 * (torch.sigmoid(4 * x) - 0.5)
            if x.sum() > args.K:
                x = args.K * x / x.sum()
            train_loss = loss_fn(mu, r, embeds, dist, bin_adj[i],
                                 train_objectives[i], args)
            if train_loss.item() < best_loss:
                best_loss = train_loss.item()
                x_best = x
        # As above: round 50 times, pick by training value, report test value.
        testvals = []
        trainvals = []
        for _ in range(50):
            y = rounding(x_best)
            testvals.append(test_objectives[i](y).item())
            trainvals.append(train_objectives[i](y).item())
        loss += testvals[np.argmin(trainvals)]
        if update:
            vals[algoname][test_instances.index(i)] = testvals[
                np.argmin(trainvals)]
    return loss / len(instances)
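

# Toy illustration of the distance-to-fractional-x conversion used above.
# Assumptions: `dist` is a (num_nodes x num_centers) tensor of node-to-center
# scores where larger means closer, and `temp` stands in for args.kcentertemp.
# Nodes that dominate some center's softmax column get a score near 1.
def _demo_fractional_x():
    dist = torch.tensor([[5.0, 0.1],   # node 0: dominates center 0
                         [0.2, 4.0],   # node 1: dominates center 1
                         [0.1, 0.2]])  # node 2: close to neither center
    temp = 30.0
    x = torch.softmax(dist * temp, 0).sum(dim=1)  # soft center-membership mass
    x = 2 * (torch.sigmoid(4 * x) - 0.5)          # squash into [0, 1)
    print(x)  # roughly [0.96, 0.96, 0.0]: nodes 0 and 1 look like centers
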
# Example 3: fragment of a training loop (the snippet begins mid-function).
                curr_test_loss = loss_test.item()
                # Convert distances into a feasible (fractional) x.
                x_best = torch.softmax(dist * args.kcentertemp, 0).sum(dim=1)
                x_best = 2 * (torch.sigmoid(4 * x_best) - 0.5)
                if x_best.sum() > K:
                    x_best = K * x_best / x_best.sum()
        losses.append(loss.item())
        optimizer.step()

    # For k-center: round 50 times, take the solution with the best training
    # value, and report that rounding's test value.
    if args.objective == 'kcenter':
        testvals = []
        trainvals = []
        for _ in range(50):
            y = rounding(x_best)
            testvals.append(obj_test(y).item())
            trainvals.append(obj_train(y).item())
        print('ClusterNet value', testvals[np.argmin(trainvals)])
    if args.objective == 'modularity':
        print('ClusterNet value', curr_test_loss)
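

# The objective callables used above (obj_train/obj_test and the
# train_objectives/test_objectives collections) are taken as given. A minimal
# sketch of a k-center objective under the standard definition (the maximum
# distance from any node to its nearest chosen center); `D` is a hypothetical
# precomputed pairwise shortest-path distance matrix.
def make_kcenter_objective(D):
    def objective(y):
        # Push non-selected columns (y_j == 0) to a huge distance so they
        # cannot act as centers, then take each node's nearest center.
        masked = D + (1 - y).unsqueeze(0) * 1e9
        return masked.min(dim=1).values.max()
    return objective
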

##############################################################################
#TRAIN TWO-STAGE
##############################################################################


def train_twostage(model_ts):
    optimizer_ts = optim.Adam(model_ts.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay)
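    # A typical two-stage step (a sketch, not necessarily this repo's exact
    # code): first fit the GCN to predict the graph, then hand the prediction
    # to a classical clustering/k-center routine, e.g.:
    #
    #   optimizer_ts.zero_grad()
    #   preds = model_ts(features, adj)               # hypothetical forward
    #   loss = torch.nn.functional.binary_cross_entropy(preds, bin_adj)
    #   loss.backward()
    #   optimizer_ts.step()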