Example #1
q, C, U_init = compute_constants(reg,
                                 device,
                                 nz,
                                 R=R,
                                 num_random_samples=num_random_samples,
                                 seed=manual_seed)
q, C, U_init = q.to(device), C.to(device), U_init.to(device)

G_generator = models.Generator(image_size, nc, k=nz, ngf=64)
D_embedding = models.Embedding(
    image_size,
    nc,
    reg,
    device,
    q,
    C,
    U_init,
    k=hidden_dim,
    num_random_samples=num_random_samples,
    R=R,
    seed=manual_seed,
    ndf=64,
    random=random_,
)

netG = models.NetG(G_generator)
path_model_G = "netG_celebA_600_1.pth"
netG.load_state_dict(torch.load(
    path_model_G, map_location="cpu"))  # drop map_location when loading on a GPU cluster
netG.to(device)

netE = models.NetE(D_embedding)
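
# A minimal sampling sketch for the loaded generator (an addition, not part of
# the original example): draw noise, decode, rescale from [-1, 1] to [0, 1],
# and save. Assumes torchvision.utils is imported as vutils, as in the training
# code below; the output filename is arbitrary.
with torch.no_grad():
    noise = torch.FloatTensor(64, nz, 1, 1).normal_(0, 1).to(device)
    samples = netG(noise).mul(0.5).add(0.5)  # tanh output -> [0, 1]
    vutils.save_image(samples, "samples_celebA.png")
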
Example #2
def training_func(
    num_random_samples,
    reg,
    batch_size,
    niter_sin,
    image_size,
    nc,
    nz,
    dataset_name,
    device,
    manual_seed,
    lr,
    max_iter,
    data_root,
    R,
    random_,
):

    name_dir = ("sampled_images_cifar_max" + "_" + str(num_random_samples) +
                "_" + str(reg))
    if not os.path.exists(name_dir):
        os.mkdir(name_dir)

    epsilon = reg
    hidden_dim = nz

    # Create an output file
    file_to_print = open(
        "results_training_cifar_max" + "_" + str(num_random_samples) + "_" +
        str(reg) + ".csv",
        "w",
    )
    file_to_print.write(str(device) + "\n")
    file_to_print.flush()

    # Fix the seed
    np.random.seed(seed=manual_seed)
    random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed(manual_seed)
    cudnn.benchmark = True

    # Initialisation of weights
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find("Conv") != -1:
            m.weight.data.normal_(0.0, 0.02)
        elif classname.find("BatchNorm") != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find("Linear") != -1:
            m.weight.data.normal_(0.0, 0.1)
            m.bias.data.fill_(0)

    trn_dataset = data_loading.get_data(image_size,
                                        dataset_name,
                                        data_root,
                                        train_flag=True)
    trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                             batch_size=batch_size,
                                             shuffle=True,
                                             num_workers=1)

    # construct Generator and Embedding:
    q, C, U_init = compute_constants(reg,
                                     device,
                                     nz,
                                     R=R,
                                     num_random_samples=num_random_samples,
                                     seed=manual_seed)

    G_generator = models.Generator(image_size, nc, k=nz, ngf=64)
    D_embedding = models.Embedding(
        image_size,
        nc,
        reg,
        device,
        q,
        C,
        U_init,
        k=hidden_dim,
        num_random_samples=num_random_samples,
        R=R,
        seed=manual_seed,
        ndf=64,
        random=random_,
    )

    netG = models.NetG(G_generator)
    netE = models.NetE(D_embedding)

    netG.apply(weights_init)
    netE.apply(weights_init)

    netG.to(device)
    netE.to(device)

    lin_Sinkhorn_AD = torch_lin_sinkhorn.Lin_Sinkhorn_AD.apply
    fixed_noise = torch.DoubleTensor(64, nz, 1, 1).normal_(0, 1).to(device)
    one = torch.tensor(1, dtype=torch.float).double()
    mone = one * -1

    # setup optimizer
    optimizerG = torch.optim.RMSprop(netG.parameters(), lr=lr)
    optimizerE = torch.optim.RMSprop(netE.parameters(), lr=lr)

    time = timeit.default_timer()
    gen_iterations = 0

    for t in range(max_iter):
        data_iter = iter(trn_loader)
        i = 0
        while i < len(trn_loader):
            # ---------------------------
            #        Optimize over NetE
            # ---------------------------
            for p in netE.parameters():
                p.requires_grad = True

            if gen_iterations < 25 or gen_iterations % 500 == 0:
                Diters = 10  # 10
                Giters = 1
            else:
                Diters = 1  # 5
                Giters = 1

            for j in range(Diters):
                if i == len(trn_loader):
                    break

                for p in netE.parameters():
                    p.data.clamp_(-0.01,
                                  0.01)  # clamp parameters of NetE to a cube

                data = next(data_iter)
                i += 1
                netE.zero_grad()

                x_cpu, _ = data
                x = x_cpu.to(device)
                x_emb = netE(x)

                noise = torch.FloatTensor(batch_size, nz, 1,
                                          1).normal_(0, 1).to(device)
                with torch.no_grad():
                    y = netG(noise)

                y_emb = netE(y)

                ### Compute the loss ###
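                # This is the debiased Sinkhorn divergence between the
                # embedded real batch x_emb and fake batch y_emb:
                #   SD(x, y) = 2*S(x, y) - S(y, y) - S(x, x)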
                sink_E = (
                    2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) -
                    lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) -
                    lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin))

                sink_E.backward(mone)
                optimizerE.step()

            # ---------------------------
            #        Optimize over NetG
            # ---------------------------
            for p in netE.parameters():
                p.requires_grad = False

            for j in range(Giters):
                if i == len(trn_loader):
                    break

                data = next(data_iter)
                i += 1
                netG.zero_grad()

                x_cpu, _ = data
                x = x_cpu.to(device)
                x_emb = netE(x)

                noise = torch.FloatTensor(batch_size, nz, 1,
                                          1).normal_(0, 1).to(device)
                y = netG(noise)
                y_emb = netE(y)

                # Same debiased Sinkhorn divergence as above, but this time the
                # gradient flows into the generator netG
                sink_G = (
                    2 * lin_Sinkhorn_AD(x_emb, y_emb, epsilon, niter_sin) -
                    lin_Sinkhorn_AD(y_emb, y_emb, epsilon, niter_sin) -
                    lin_Sinkhorn_AD(x_emb, x_emb, epsilon, niter_sin))

                sink_G.backward(one)
                optimizerG.step()

                gen_iterations += 1

            run_time = (timeit.default_timer() - time) / 60.0

            s = "[%3d / %3d] [%3d / %3d] [%5d] (%.2f m) loss_E: %.6f loss_G: %.6f" % (
                t,
                max_iter,
                i * batch_size,
                batch_size * len(trn_loader),
                gen_iterations,
                run_time,
                sink_E,
                sink_G,
            )

            s = s + "\n"
            file_to_print.write(s)
            file_to_print.flush()

            if gen_iterations % 100 == 0:
                with torch.no_grad():
                    fixed_noise = fixed_noise.float()
                    y_fixed = netG(fixed_noise)
                    y_fixed = y_fixed.mul(0.5).add(0.5)
                    vutils.save_image(
                        y_fixed,
                        "{0}/fake_samples_{1}.png".format(
                            name_dir, gen_iterations),
                    )

        if t % 10 == 0:
            torch.save(
                netG.state_dict(),
                "netG_cifar_max" + "_" + str(num_random_samples) + "_" +
                str(reg) + ".pth",
            )
            torch.save(
                netE.state_dict(),
                "netE_cifar_max" + "_" + str(num_random_samples) + "_" +
                str(reg) + ".pth",
            )
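
# A hypothetical invocation of training_func, kept commented out; every value
# below is an illustrative placeholder, not taken from the original script.
# training_func(num_random_samples=1000, reg=1.0, batch_size=200, niter_sin=10,
#               image_size=32, nc=3, nz=100, dataset_name="cifar10",
#               device="cuda", manual_seed=42, lr=5e-5, max_iter=1000,
#               data_root="./data", R=1.0, random_=True)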
Example #3
def main(args):
    #     print(args)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Dataset processing

    input_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/"
    all_events = os.listdir(input_dir)
    all_events = [input_dir + event[:14] for event in all_events]
    np.random.shuffle(all_events)

    train_dataset = [
        prepare_event(event_file, args.pt_cut, [1000, np.pi, 1000],
                      args.adjacent)
        for event_file in all_events[:args.train_size]
    ]
    test_dataset = [
        prepare_event(event_file, args.pt_cut, [1000, np.pi, 1000],
                      args.adjacent)
        for event_file in all_events[-args.val_size:]
    ]
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

    # Model config
    e_configs = {
        "in_channels": 3,
        "emb_hidden": args.emb_hidden,
        "nb_layer": args.nb_layer,
        "emb_dim": args.emb_dim,
    }
    m_configs = {
        "in_channels": 3,
        "emb_hidden": args.emb_hidden,
        "nb_layer": args.nb_layer,
        "emb_dim": args.emb_dim,
        "r": args.r_val,
        "hidden_dim": args.hidden_dim,
        "n_graph_iters": args.n_graph_iters,
    }
    other_configs = {
        "weight": args.weight,
        "r_train": args.r_train,
        "r_val": args.r_val,
        "margin": args.margin,
        "reduction": "mean",
    }

    # Create and pretrain embedding
    embedding_model = models.Embedding(**e_configs).to(device)
    wandb.init(group="EmbeddingToAGNN_PurTimesEff", config=m_configs)
    embedding_optimizer = torch.optim.Adam(embedding_model.parameters(),
                                           lr=0.0005,
                                           weight_decay=1e-3,
                                           amsgrad=True)

    for epoch in range(args.pretrain_epochs):
        tic = tt()
        embedding_model.train()
        cluster_pur, train_loss = train_emb(embedding_model, train_loader,
                                            embedding_optimizer, other_configs)

        embedding_model.eval()
        with torch.no_grad():
            cluster_pur, cluster_eff, val_loss, av_nhood_size = evaluate_emb(
                embedding_model, test_loader, other_configs)
        wandb.log({
            "val_loss": val_loss,
            "train_loss": train_loss,
            "cluster_pur": cluster_pur,
            "cluster_eff": cluster_eff,
            "av_nhood_size": av_nhood_size,
        })

    # Create and train main model
    model = getattr(models,
                    args.model)(**m_configs,
                                pretrained_model=embedding_model).to(device)
    multi_loss = models.MultiNoiseLoss(n_losses=2).to(device)
    m_configs.update(other_configs)
    wandb.run.save()
    print(wandb.run.name)
    model_name = wandb.run.name
    wandb.watch(model, log="all")

    # Optimizer config

    optimizer = torch.optim.AdamW(
        [
            {
                "params":
                chain(model.emb_network_1.parameters(),
                      model.emb_network_2.parameters())
            },
            {
                "params":
                chain(
                    model.node_network.parameters(),
                    model.edge_network.parameters(),
                    model.input_feature_network.parameters(),
                )
            },
            {
                "params": multi_loss.noise_params
            },
        ],
        lr=0.001,
        weight_decay=1e-3,
        amsgrad=True,
    )

    # Scheduler config

    lambda1 = lambda ep: 1 / (args.lr_1**(ep // 10))
    lambda2 = lambda ep: 1 / (args.lr_2**(ep // 30))
    lambda3 = lambda ep: 1 / (args.lr_3**(ep // 10))
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=[lambda1, lambda2, lambda3])
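
    # Note: LambdaLR takes one lr_lambda per optimizer param group, so the
    # three schedules above decay the embedding networks, the graph networks,
    # and the multi-loss noise parameters independently.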

    # Training loop

    for epoch in range(50):
        tic = tt()
        model.train()
        if args.adjacent:
            edge_acc, cluster_pur, train_loss = balanced_adjacent_train(
                model, train_loader, optimizer, multi_loss, m_configs)
        else:
            edge_acc, cluster_pur, train_loss = balanced_train(
                model, train_loader, optimizer, multi_loss, m_configs)
        #         print("Training loss:", train_loss)

        model.eval()
        if args.adjacent:
            with torch.no_grad():
                (
                    edge_acc,
                    edge_pur,
                    edge_eff,
                    cluster_pur,
                    cluster_eff,
                    val_loss,
                    av_nhood_size,
                ) = evaluate_adjacent(model, test_loader, multi_loss,
                                      m_configs)
        else:
            with torch.no_grad():
                (
                    edge_acc,
                    edge_pur,
                    edge_eff,
                    cluster_pur,
                    cluster_eff,
                    val_loss,
                    av_nhood_size,
                ) = evaluate(model, test_loader, multi_loss, m_configs)
        scheduler.step()
        wandb.log({
            "val_loss": val_loss,
            "train_loss": train_loss,
            "edge_acc": edge_acc,
            "edge_pur": edge_pur,
            "edge_eff": edge_eff,
            "cluster_pur": cluster_pur,
            "cluster_eff": cluster_eff,
            "lr": scheduler._last_lr[0],
            "combined_performance":
            edge_eff * cluster_eff * edge_pur + cluster_pur,
            "combined_efficiency": edge_eff * cluster_eff * edge_pur,
            "noise_1": multi_loss.noise_params[0].item(),
            "noise_2": multi_loss.noise_params[1].item(),
            "av_nhood_size": av_nhood_size,
        })

        save_model(
            epoch,
            model,
            optimizer,
            scheduler,
            cluster_eff,
            m_configs,
            "EmbeddingToAGNN/" + model_name + ".tar",
        )
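
# A minimal argparse sketch of the flags main() reads. The flag names come
# from the attribute accesses above; every default value (and the model name)
# is an illustrative assumption, not taken from the original script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="EmbeddingToAGNN")
    parser.add_argument("--adjacent", action="store_true")
    parser.add_argument("--pt_cut", type=float, default=1.0)
    parser.add_argument("--train_size", type=int, default=100)
    parser.add_argument("--val_size", type=int, default=10)
    parser.add_argument("--emb_hidden", type=int, default=512)
    parser.add_argument("--nb_layer", type=int, default=4)
    parser.add_argument("--emb_dim", type=int, default=8)
    parser.add_argument("--hidden_dim", type=int, default=64)
    parser.add_argument("--n_graph_iters", type=int, default=8)
    parser.add_argument("--weight", type=float, default=1.0)
    parser.add_argument("--margin", type=float, default=1.0)
    parser.add_argument("--r_train", type=float, default=1.0)
    parser.add_argument("--r_val", type=float, default=1.0)
    parser.add_argument("--pretrain_epochs", type=int, default=5)
    parser.add_argument("--lr_1", type=float, default=2.0)
    parser.add_argument("--lr_2", type=float, default=2.0)
    parser.add_argument("--lr_3", type=float, default=2.0)
    main(parser.parse_args())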
Example #4
#raw_test = fetch_20newsgroups(subset='test', categories=categories, data_home='../../..')
raw_train = fetch_20newsgroups(subset='train', data_home="./")
raw_test = fetch_20newsgroups(subset='test', data_home="./")

#print(len(raw_train.data),len(raw_test.data),Const.OUTPUT)

train_set = utils.make_dataset(raw_train)
test_data = utils.make_dataset(raw_test)

vocab = Vocabulary(min_freq=10).from_dataset(train_set, field_name='words')
vocab.index_dataset(train_set, field_name='words', new_field_name='words')
vocab.index_dataset(test_data, field_name='words', new_field_name='words')
train_data, dev_data = train_set.split(0.1)
print(len(train_data), len(dev_data), len(test_data), len(vocab))

embed = models.Embedding(len(vocab), options.embed_dim)
if options.model == "RNN":
    model = models.RNNText(embed,
                           options.hidden_size,
                           20,
                           dropout=options.dropout,
                           layers=options.layers)
elif options.model == "CNN":
    model = models.CNNText(embed,
                           20,
                           dropout=options.dropout,
                           padding=vocab.padding_idx)
elif options.model == "CNNKMAX":
    model = models.CNN_KMAX(embed,
                            20,
                            dropout=options.dropout,
                            padding=vocab.padding_idx)
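
# A hedged sketch of a possible next step with fastNLP's Trainer, kept
# commented out; the 'target' field name is an assumption about how
# utils.make_dataset labels its fields.
# from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
# trainer = Trainer(train_data=train_data, model=model,
#                   loss=CrossEntropyLoss(target='target'),
#                   metrics=AccuracyMetric(target='target'),
#                   dev_data=dev_data)
# trainer.train()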