Code Example #1
File: main.py Project: sznajder/mnist_graph_gan
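# (excerpt begins mid-block: the truncated "if" above presumably checks,
#  via onlydirs, whether output directories for this run name already exist)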
    #     sys.exit()
else:
    os.mkdir('./figs/' + name)
    os.mkdir('./losses/' + name)
    os.mkdir('./models/' + name)

del onlydirs

f = open("args/" + name + ".txt", "w+")
f.write(str(locals()))
f.close()

# Change to True !!
X = MNISTGraphDataset(num_hits,
                      train=TRAIN,
                      num=NUM,
                      intensities=INTENSITIES,
                      mnist8m=MNIST8M)
X_loaded = DataLoader(X, shuffle=True, batch_size=batch_size)

if (LOAD_MODEL):
    start_epoch = 255
    G = torch.load("models/" + name + "/G_" + str(start_epoch) + ".pt")
    D = torch.load("models/" + name + "/D_" + str(start_epoch) + ".pt")
else:
    start_epoch = 0
    G = Simple_GRU(node_size,
                   fe_out_size,
                   gru_hidden_size,
                   gru_num_layers,
                   num_iters,
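
The per-run directory setup at the top of this excerpt can be written more defensively. A minimal equivalent sketch using os.makedirs, with name as in the snippet; exist_ok is an addition here, not part of the original:

import os

# Same effect as the three os.mkdir calls above, but idempotent:
# exist_ok=True tolerates directories that already exist.
for base in ("figs", "losses", "models"):
    os.makedirs(os.path.join(base, name), exist_ok=True)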
Code Example #2
gru_hidden_size = 100
gru_num_layers = 3
dropout = 0.3
batch_size = 1024
num_thresholded = 100
gen_in_dim = 100
lr = 0.00005
lr_disc = 0.0001
lr_gen = 0.00005
num_critic = 1
weight_clipping_limit = 1

torch.manual_seed(4)

# Change to True !!
X = MNISTGraphDataset(num_thresholded, train=True)
X_loaded = DataLoader(X, shuffle=True, batch_size=batch_size)

name = "22_wgan"

if (LOAD_MODEL):
    start_epoch = 10
    G = torch.load("models/" + name + "_G_" + str(start_epoch) + ".pt")
    D = torch.load("models/" + name + "_D_" + str(start_epoch) + ".pt")
else:
    start_epoch = 0
    G = Simple_GRU(input_size, output_size, gen_in_dim, gru_hidden_size,
                   gru_num_layers, dropout, batch_size).cuda()
    D = Critic((num_thresholded, input_size), dropout, batch_size,
               wgan=True).cuda()
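
Example #2 sets weight_clipping_limit but the critic update itself is not shown. For reference, the standard WGAN weight-clipping step (Arjovsky et al. 2017) is sketched below; whether this project clips exactly this way is an assumption:

# After each critic optimizer step, clamp every critic parameter into
# [-c, c] to enforce the Lipschitz constraint required by the WGAN loss.
c = weight_clipping_limit
for p in D.parameters():
    p.data.clamp_(-c, c)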
Code Example #3
File: main.py Project: sznajder/mnist_graph_gan
def main(args):
    args = init(args)

    def pf(data):
        return data.y == args.num

    pre_filter = pf if args.num != -1 else None

    print("loading data")

    if (args.sparse_mnist):
        X = MNISTGraphDataset(args.dataset_path,
                              args.num_hits,
                              train=args.train,
                              num=args.num)
        X_loaded = DataLoader(X,
                              shuffle=True,
                              batch_size=args.batch_size,
                              pin_memory=True)
    else:
        if (args.gcnn):
            X = MNISTSuperpixels(args.dir_path,
                                 train=args.train,
                                 pre_transform=T.Cartesian(),
                                 pre_filter=pre_filter)
            X_loaded = tgDataLoader(X,
                                    shuffle=True,
                                    batch_size=args.batch_size)
        else:
            X = SuperpixelsDataset(args.dataset_path,
                                   args.num_hits,
                                   train=args.train,
                                   num=args.num)
            X_loaded = DataLoader(X,
                                  shuffle=True,
                                  batch_size=args.batch_size,
                                  pin_memory=True)

    print("loaded data")

    # model

    if (args.load_model):
        G = torch.load(args.model_path + args.name + "/G_" +
                       str(args.start_epoch) + ".pt",
                       map_location=args.device)
        D = torch.load(args.model_path + args.name + "/D_" +
                       str(args.start_epoch) + ".pt",
                       map_location=args.device)

    else:
        # G = Graph_Generator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.fn_hidden_size, args.fn_num_layers, args.mp_iters_gen, args.num_hits, args.gen_dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, int_diffs=args.int_diffs, pos_diffs=args.pos_diffs, gru=args.gru, batch_norm=args.batch_norm, device=device).to(args.device)
        if (args.gcnn):
            G = GaussianGenerator(args=deepcopy(args)).to(args.device)
            D = MoNet(args=deepcopy(args)).to(args.device)
            # D = Gaussian_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.mp_hidden_size, args.mp_num_layers, args.num_iters, args.num_hits, args.dropout, args.leaky_relu_alpha, kernel_size=args.kernel_size, hidden_node_size=args.hidden_node_size, int_diffs=args.int_diffs, gru=GRU, batch_norm=args.batch_norm, device=device).to(args.device)
        else:
            # D = Graph_Discriminator(args.node_feat_size, args.fe_hidden_size, args.fe_out_size, args.fn_hidden_size, args.fn_num_layers, args.mp_iters_disc, args.num_hits, args.disc_dropout, args.leaky_relu_alpha, hidden_node_size=args.hidden_node_size, wgan=args.wgan, int_diffs=args.int_diffs, pos_diffs=args.pos_diffs, gru=args.gru, batch_norm=args.batch_norm, device=device).to(args.device)
            print("Generator")
            G = Graph_GAN(gen=True, args=deepcopy(args)).to(args.device)
            print("Discriminator")
            D = Graph_GAN(gen=False, args=deepcopy(args)).to(args.device)

    print("Models loaded")

    # optimizer

    if args.spectral_norm_gen:
        G_params = filter(lambda p: p.requires_grad, G.parameters())
    else:
        G_params = G.parameters()

    if args.spectral_norm_disc:
        D_params = filter(lambda p: p.requires_grad, D.parameters())
    else:
        D_params = D.parameters()

    if (args.optimizer == 'rmsprop'):
        G_optimizer = optim.RMSprop(G_params, lr=args.lr_gen)
        D_optimizer = optim.RMSprop(D_params, lr=args.lr_disc)
    elif (args.optimizer == 'adadelta'):
        G_optimizer = optim.Adadelta(G_params, lr=args.lr_gen)
        D_optimizer = optim.Adadelta(D_params, lr=args.lr_disc)
    elif (args.optimizer == 'acgd'):
        optimizer = ACGD(max_params=G_params,
                         min_params=D_params,
                         lr_max=args.lr_gen,
                         lr_min=args.lr_disc,
                         device=args.device)
    elif (args.optimizer == 'adam' or args.optimizer == 'None'):
        G_optimizer = optim.Adam(G_params,
                                 lr=args.lr_gen,
                                 weight_decay=5e-4,
                                 betas=(args.beta1, args.beta2))
        D_optimizer = optim.Adam(D_params,
                                 lr=args.lr_disc,
                                 weight_decay=5e-4,
                                 betas=(args.beta1, args.beta2))

    if (args.load_model):
        try:
            if (not args.optimizer == 'acgd'):
                G_optimizer.load_state_dict(
                    torch.load(args.model_path + args.name + "/G_optim_" +
                               str(args.start_epoch) + ".pt",
                               map_location=args.device))
                D_optimizer.load_state_dict(
                    torch.load(args.model_path + args.name + "/D_optim_" +
                               str(args.start_epoch) + ".pt",
                               map_location=args.device))
            else:
                optimizer.load_state_dict(
                    torch.load(args.model_path + args.name + "/optim_" +
                               str(args.start_epoch) + ".pt",
                               map_location=args.device))
        except Exception as e:
            print("Error loading optimizer:", e)

    print("optimizers loaded")

    if args.fid: C, mu2, sigma2 = evaluation.load(args, X_loaded)

    normal_dist = Normal(
        torch.tensor(0.).to(args.device),
        torch.tensor(args.sd).to(args.device))

    lns = args.latent_node_size if args.latent_node_size else args.hidden_node_size

    args.noise_file_name = "num_samples_" + str(
        args.num_samples) + "_num_nodes_" + str(
            args.num_hits) + "_latent_node_size_" + str(lns) + "_sd_" + str(
                args.sd) + ".pt"
    if args.gcnn: args.noise_file_name = "gcnn_" + args.noise_file_name

    noise_file_names = listdir(args.noise_path)

    if args.noise_file_name not in noise_file_names:
        if (args.gcnn):
            torch.save(
                normal_dist.sample(
                    (args.num_samples * 5, 2 + args.channels[0])),
                args.noise_path + args.noise_file_name)
        else:
            torch.save(
                normal_dist.sample((args.num_samples, args.num_hits, lns)),
                args.noise_path + args.noise_file_name)

    losses = {}

    if (args.load_model):
        try:
            losses['D'] = np.loadtxt(args.losses_path + args.name + "/" +
                                     "D.txt").tolist()[:args.start_epoch]
            losses['Dr'] = np.loadtxt(args.losses_path + args.name + "/" +
                                      "Dr.txt").tolist()[:args.start_epoch]
            losses['Df'] = np.loadtxt(args.losses_path + args.name + "/" +
                                      "Df.txt").tolist()[:args.start_epoch]
            losses['G'] = np.loadtxt(args.losses_path + args.name + "/" +
                                     "G.txt").tolist()[:args.start_epoch]
            if args.fid:
                losses['fid'] = np.loadtxt(
                    args.losses_path + args.name + "/" +
                    "fid.txt").tolist()[:args.start_epoch]
            if (args.gp):
                losses['gp'] = np.loadtxt(args.losses_path + args.name + "/" +
                                          "gp.txt").tolist()[:args.start_epoch]
        except Exception as e:
            print("couldn't load losses:", e)
            losses['D'] = []
            losses['Dr'] = []
            losses['Df'] = []
            losses['G'] = []
            if args.fid: losses['fid'] = []
            if (args.gp): losses['gp'] = []

    else:
        losses['D'] = []
        losses['Dr'] = []
        losses['Df'] = []
        losses['G'] = []
        if args.fid: losses['fid'] = []
        if (args.gp): losses['gp'] = []

    Y_real = torch.ones(args.batch_size, 1).to(args.device)
    Y_fake = torch.zeros(args.batch_size, 1).to(args.device)

    def train_D(data, gen_data=None, unrolled=False):
        if args.debug: print("dtrain")
        D.train()
        D_optimizer.zero_grad()

        run_batch_size = data.shape[0] if not args.gcnn else data.y.shape[0]

        if gen_data is None:
            gen_data = utils.gen(args, G, normal_dist, run_batch_size)
            if (args.gcnn):
                gen_data = utils.convert_to_batch(args, gen_data,
                                                  run_batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            data = augment.augment(args, data, p)
            gen_data = augment.augment(args, gen_data, p)

        D_real_output = D(data.clone())
        D_fake_output = D(gen_data)

        D_loss, D_loss_items = utils.calc_D_loss(args, D, data, gen_data,
                                                 D_real_output, D_fake_output,
                                                 run_batch_size, Y_real,
                                                 Y_fake)
        D_loss.backward(create_graph=unrolled)

        D_optimizer.step()
        return D_loss_items

    def train_G(data):
        if args.debug: print("gtrain")
        G.train()
        G_optimizer.zero_grad()

        gen_data = utils.gen(args, G, normal_dist, args.batch_size)
        if (args.gcnn):
            gen_data = utils.convert_to_batch(args, gen_data, args.batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            gen_data = augment.augment(args, gen_data, p)

        if (args.unrolled_steps > 0):
            D_backup = deepcopy(D)
            for i in range(args.unrolled_steps - 1):
                train_D(data, gen_data=gen_data, unrolled=True)

        D_fake_output = D(gen_data)

        G_loss = utils.calc_G_loss(args, D_fake_output, Y_real)

        G_loss.backward()
        G_optimizer.step()

        if (args.unrolled_steps > 0):
            D.load_state_dict(D_backup.state_dict())  # restore pre-unroll weights

        return G_loss.item()

    def train_acgd(data):
        if args.debug: print("acgd train")
        D.train()
        G.train()
        optimizer.zero_grad()

        run_batch_size = data.shape[0] if not args.gcnn else data.y.shape[0]

        gen_data = utils.gen(args, G, normal_dist, run_batch_size)
        if (args.gcnn):
            gen_data = utils.convert_to_batch(args, gen_data, run_batch_size)

        if args.augment:
            p = args.aug_prob if not args.adaptive_prob else losses['p'][-1]
            data = utils.rand_translate(args, data, p)
            gen_data = utils.rand_translate(args, gen_data, p)

        D_real_output = D(data.clone())
        D_fake_output = D(gen_data)

        D_loss, D_loss_items = utils.calc_D_loss(args, D, data, gen_data,
                                                 D_real_output, D_fake_output,
                                                 run_batch_size)

        optimizer.step(loss=D_loss)

        G.eval()
        with torch.no_grad():
            G_loss = utils.calc_G_loss(args, D_fake_output)

        return D_loss_items, G_loss.item()

    def train():
        k = 0
        temp_ng = args.num_gen
        if (args.fid):
            losses['fid'].append(
                evaluation.get_fid(args, C, G, normal_dist, mu2, sigma2))
        if (args.save_zero):
            save_outputs.save_sample_outputs(args, D, G, normal_dist,
                                             args.name, 0, losses)
        for i in range(args.start_epoch, args.num_epochs):
            print("Epoch %d %s" % ((i + 1), args.name))
            Dr_loss = 0
            Df_loss = 0
            G_loss = 0
            D_loss = 0
            gp_loss = 0
            lenX = len(X_loaded)
            for batch_ndx, data in tqdm(enumerate(X_loaded), total=lenX):
                data = data.to(args.device)
                if (args.gcnn):
                    data.pos = (data.pos - 14) / 28
                    row, col = data.edge_index
                    data.edge_attr = (data.pos[col] -
                                      data.pos[row]) / (2 * args.cutoff) + 0.5

                if (not args.optimizer == 'acgd'):
                    if (args.num_critic > 1):
                        D_loss_items = train_D(data)
                        D_loss += D_loss_items['D']
                        Dr_loss += D_loss_items['Dr']
                        Df_loss += D_loss_items['Df']
                        if (args.gp): gp_loss += D_loss_items['gp']

                        if ((batch_ndx - 1) % args.num_critic == 0):
                            G_loss += train_G(data)
                    else:
                        if (batch_ndx == 0
                                or (batch_ndx - 1) % args.num_gen == 0):
                            D_loss_items = train_D(data)
                            D_loss += D_loss_items['D']
                            Dr_loss += D_loss_items['Dr']
                            Df_loss += D_loss_items['Df']
                            if (args.gp): gp_loss += D_loss_items['gp']

                        G_loss += train_G(data)
                else:
                    D_loss_items, G_loss_item = train_acgd(data)
                    D_loss += D_loss_items['D']
                    Dr_loss += D_loss_items['Dr']
                    Df_loss += D_loss_items['Df']
                    G_loss += G_loss_item

                if args.bottleneck:
                    if (batch_ndx == 10):
                        return

            losses['D'].append(D_loss / (lenX / args.num_gen))
            losses['Dr'].append(Dr_loss / (lenX / args.num_gen))
            losses['Df'].append(Df_loss / (lenX / args.num_gen))
            losses['G'].append(G_loss / (lenX / args.num_critic))
            if (args.gp): losses['gp'].append(gp_loss / (lenX / args.num_gen))

            print("d loss: " + str(losses['D'][-1]))
            print("g loss: " + str(losses['G'][-1]))
            print("dr loss: " + str(losses['Dr'][-1]))
            print("df loss: " + str(losses['Df'][-1]))

            if (args.gp): print("gp loss: " + str(losses['gp'][-1]))

            gloss = losses['G'][-1]
            drloss = losses['Dr'][-1]
            dfloss = losses['Df'][-1]
            dloss = (drloss + dfloss) / 2

            if (args.bgm):
                if (i > 20 and gloss > dloss + args.bag):
                    print("num gen upping to 10")
                    args.num_gen = 10
                else:
                    print("num gen normal")
                    args.num_gen = temp_ng
            elif (args.gom):
                if (i > 20 and gloss > dloss + args.bag):
                    print("G loss too high - training G only")
                    j = 0
                    print("starting g loss: " + str(gloss))
                    print("starting d loss: " + str(dloss))

                    while (gloss > dloss + args.bag * 0.5):
                        print(j)
                        gloss = 0
                        for l in tqdm(range(lenX)):
                            gloss += train_G(data)  # train_G needs a batch; reuse the last one

                        gloss /= lenX
                        print("g loss: " + str(gloss))
                        print("d loss: " + str(dloss))

                        losses['D'].append(dloss * 2)
                        losses['Dr'].append(drloss)
                        losses['Df'].append(dfloss)
                        losses['G'].append(gloss)

                        if (j % 5 == 0):
                            save_outputs.save_sample_outputs(args,
                                                             D,
                                                             G,
                                                             normal_dist,
                                                             args.name,
                                                             i + 1,
                                                             losses,
                                                             k=k,
                                                             j=j)

                        j += 1

                    k += 1
            elif (args.rd):
                if (i > 20 and gloss > dloss + args.bag):
                    print("gloss too high, resetting D params")
                    D.reset_params()

            if ((i + 1) % 5 == 0):
                optimizers = optimizer if args.optimizer == 'acgd' else (
                    D_optimizer, G_optimizer)
                save_outputs.save_models(args, D, G, optimizers, args.name,
                                         i + 1)

            if args.fid:  # evaluate FID every epoch
                losses['fid'].append(
                    evaluation.get_fid(args, C, G, normal_dist, mu2, sigma2))

            if ((i + 1) % 5 == 0):
                save_outputs.save_sample_outputs(args, D, G, normal_dist,
                                                 args.name, i + 1, losses)

    train()
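
Example #3 toggles a gradient-penalty term via args.gp but delegates the computation to utils.calc_D_loss. The standard WGAN-GP penalty such a flag usually selects is sketched below; this is a generic implementation assuming point-cloud batches of shape (batch, nodes, features), not necessarily the project's exact code:

import torch

def gradient_penalty(D, real, fake, device, lambda_gp=10.0):
    # WGAN-GP: penalize the critic when its gradient norm at points
    # interpolated between real and fake samples deviates from 1.
    eps = torch.rand(real.shape[0], 1, 1, device=device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_out = D(interp)
    grads = torch.autograd.grad(d_out, interp,
                                grad_outputs=torch.ones_like(d_out),
                                create_graph=True)[0]
    grad_norm = grads.reshape(grads.shape[0], -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()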
Code Example #4
def main(args):
    args = init_dirs(args)

    pt = T.Cartesian() if args.cartesian else T.Polar()

    if args.dataset == 'sp':
        train_dataset = MNISTSuperpixels(args.dataset_path, True, pre_transform=pt)
        test_dataset = MNISTSuperpixels(args.dataset_path, False, pre_transform=pt)
        train_loader = tgDataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
        test_loader = tgDataLoader(test_dataset, batch_size=args.batch_size)
    elif args.dataset == 'sm':
        train_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=True)
        train_loader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size, pin_memory=True)
        test_dataset = MNISTGraphDataset(args.dataset_path, args.num_hits, train=False)
        test_loader = DataLoader(test_dataset, shuffle=True, batch_size=args.batch_size, pin_memory=True)

    if(args.load_model):
        C = torch.load(args.model_path + args.name + "/C_" + str(args.start_epoch) + ".pt").to(device)
    else:
        C = MoNet(args.kernel_size).to(device)

    C_optimizer = torch.optim.Adam(C.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if(args.scheduler):
        C_scheduler = torch.optim.lr_scheduler.StepLR(C_optimizer, args.decay_step, gamma=args.lr_decay)

    train_losses = []
    test_losses = []

    def plot_losses(epoch, train_losses, test_losses):
        fig = plt.figure()
        ax1 = fig.add_subplot(1, 2, 1)
        ax1.plot(train_losses)
        ax1.set_title('training')
        ax2 = fig.add_subplot(1, 2, 2)
        ax2.plot(test_losses)
        ax2.set_title('testing')

        plt.savefig(args.losses_path + args.name + "/" + str(epoch) + ".png")
        plt.close()

    def save_model(epoch):
        torch.save(C, args.model_path + args.name + "/C_" + str(epoch) + ".pt")

    def train_C(data, y):
        C.train()
        C_optimizer.zero_grad()

        output = C(data)

        # nll_loss takes class labels as target, so one-hot encoding is not needed
        C_loss = F.nll_loss(output, y)

        C_loss.backward()
        C_optimizer.step()

        return C_loss.item()

    def test(epoch):
        C.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data in test_loader:
                if args.dataset == 'sp':
                    output = C(data.to(device))
                    y = data.y.to(device)
                elif args.dataset == 'sm':
                    output = C(tg_transform(args, data[0].to(device)))
                    y = data[1].to(device)

                test_loss += F.nll_loss(output, y, reduction='sum').item()
                pred = output.data.max(1, keepdim=True)[1]
                correct += pred.eq(y.data.view_as(pred)).sum()

        test_loss /= len(test_loader.dataset)
        test_losses.append(test_loss)

        print('test')

        f = open(args.out_path + args.name + '.txt', 'a')
        print(args.out_path + args.name + '.txt')
        s = "After {} epochs, on test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(epoch, test_loss, correct, len(test_loader.dataset), 100. * correct / len(test_loader.dataset))
        print(s)
        f.write(s)
        f.close()

    for i in range(args.start_epoch, args.num_epochs):
        print("Epoch %d %s" % ((i + 1), args.name))
        C_loss = 0
        test(i)
        for batch_ndx, data in tqdm(enumerate(train_loader), total=len(train_loader)):
            if args.dataset == 'sp':
                C_loss += train_C(data.to(device), data.y.to(device))
            elif args.dataset == 'sm':
                C_loss += train_C(tg_transform(args, data[0].to(device)), data[1].to(device))

        train_losses.append(C_loss / len(train_loader))

        if(args.scheduler):
            C_scheduler.step()

        if((i + 1) % 10 == 0):
            save_model(i + 1)
            plot_losses(i + 1, train_losses, test_losses)

    test(args.num_epochs)
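
As the comment in train_C notes, F.nll_loss takes integer class labels directly; it also expects log-probabilities as input, so the model's final layer should apply log_softmax. A minimal self-contained illustration:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)               # batch of 4, 10 classes
log_probs = F.log_softmax(logits, dim=1)  # nll_loss expects log-probs
labels = torch.tensor([3, 1, 7, 0])       # plain class labels, no one-hot
loss = F.nll_loss(log_probs, labels)      # == F.cross_entropy(logits, labels)
print(loss.item())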
Code Example #5
import setGPU
import torch
import torchvision
import torch.nn as nn
from torch.optim import Adam
from gcn import GCN_classifier
from graph_dataset_mnist import MNISTGraphDataset

batch_size = 128
num_thresholded = 100

transforms = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()])

X_test = MNISTGraphDataset(num_thresholded, train=False)
# X_train = MNISTGraphDataset(num_thresholded, train=True)

# X_train_loaded = torch.utils.data.DataLoader(X_train, shuffle=True, batch_size=batch_size)
X_test_loaded = torch.utils.data.DataLoader(X_test,
                                            shuffle=False,
                                            batch_size=batch_size)

model = GCN_classifier(3, 256, 10, 0.3)

loss_func = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)

for e in range(5):
    print(e)
    running_loss = 0
    for i, data in enumerate(X_test_loaded):
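
The listing is cut off inside the epoch loop. A generic body for this setup might look like the following; this is an assumed continuation, since the original body is not shown, and it presumes each batch is an (inputs, labels) pair and that GCN_classifier returns class logits:

# Hypothetical body for the truncated loop above:
for i, data in enumerate(X_test_loaded):
    inputs, labels = data
    optimizer.zero_grad()
    outputs = model(inputs)            # class logits from GCN_classifier
    loss = loss_func(outputs, labels)  # CrossEntropyLoss takes raw logits
    loss.backward()
    optimizer.step()
    running_loss += loss.item()
print("epoch %d avg loss: %.3f" % (e, running_loss / len(X_test_loaded)))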