Example #1
0
def peek_at_classes():
    """Save a preview grid of the ten highest-scoring images for every class.

    For each class returned by ``PokemonDataset.get_classes()``, takes that
    class's column from the predictions CSV at ``PREDS_PATH``, picks the ten
    rows with the largest values and plots the matching dataset images in a
    2x5 grid. Each subplot is titled with the rounded prediction and the
    image's labels. One PNG per class is written to
    ``figs/<class>_preds_preview.png``.
    """
    dataset = PokemonDataset()
    classes = dataset.get_classes()
    preds = pd.read_csv(PREDS_PATH)

    for cls in classes:
        # Ten largest predictions, highest first. ``nlargest`` skips NaN,
        # whereas ``sort_values()[-10:]`` would place NaN rows at the top.
        cls_preds = preds[cls].nlargest(10)
        fig, axs = plt.subplots(2, 5, figsize=(10, 5))

        # ``Series.items`` replaces ``iteritems``, removed in pandas 2.0.
        for (idx, prob), ax in zip(cls_preds.items(), axs.flatten()):
            img, labels = dataset[idx]
            # Drop empty/missing labels. A missing type may be None, "" or a
            # float NaN; NaN is truthy, so it needs the explicit float check.
            labels = ', '.join(
                [l for l in labels if l and not isinstance(l, float)])

            ax.imshow(img)
            ax.set_title(str(round(prob, 4)) + '\n' + labels)
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(os.path.join("figs", f"{cls}_preds_preview.png"))
        # Release the figure so memory does not grow with the class count.
        plt.close(fig)
Example #2
0
def train():
    """Fit the auto-encoder to the mirror-augmented Pokemon images and save it.

    The reconstruction is scored with mean-squared error against the input
    and optimised with Adam. ``train.log`` receives a status line that is
    overwritten after every batch, plus one persistent line per epoch.
    """
    # ---- Data -------------------------------------------------------------
    loader = DataLoader(PokemonDataset(add_mirrored=True),
                        batch_size=BATCH_SIZE, shuffle=SHUFFLE)

    # ---- Model, optimiser, loss ------------------------------------------
    model = models.AutoEncoder()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    loss_fn = torch.nn.MSELoss()

    # ---- Training loop ---------------------------------------------------
    model.train()
    log = Logger('train.log', TrainingLogTemplate())

    for epoch_idx in range(EPOCHS):
        running_loss = 0
        status = {}
        for batch_no, (batch, _) in enumerate(loader, start=1):
            # The target of an auto-encoder is its own input.
            reconstruction = model(batch)
            optimizer.zero_grad()
            loss = loss_fn(reconstruction, batch)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            status = dict(
                epoch=epoch_idx + 1,
                epochs=EPOCHS,
                step=batch_no,
                loss=running_loss / batch_no,
                data_len=len(loader),
            )
            log.write(status, overwrite=True)
        # Keep the final status of the epoch as a non-overwritten line.
        log.write(status)
    log.close()

    # ---- Persist ---------------------------------------------------------
    model.save_states()
Example #3
0
def dataset_preview():
    """Save a 3x6 grid showing one example image per Pokemon type.

    Items come from ``PokemonDataset().fetch_per_type_examples()`` as
    ``(image, types)`` pairs. Each subplot is titled with the first type, and
    also the second type when one is present. The grid is written to
    ``figs/dataset_preview.png``.
    """
    data = PokemonDataset().fetch_per_type_examples()
    fig, axs = plt.subplots(3, 6, figsize=(14, 7))

    for (img, cls), ax in zip(data, axs.flatten()):
        second = cls[1]
        # A missing second type shows up as a float NaN (possibly a numpy
        # float subclass, which ``type(x) is float`` would miss) or as
        # None / "" in other code paths of this project.
        has_second = not (second is None
                          or isinstance(second, float)
                          or second == "")
        ax.imshow(img)
        ax.set_title(f"{cls[0]}, {second}" if has_second else cls[0])
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join("figs", "dataset_preview.png"))
    # Release the figure once it has been written.
    plt.close(fig)
Example #4
0
def test():
    """Evaluate the saved auto-encoder and save input/reconstruction visuals.

    Loads the trained weights with ``load_states()`` and runs every dataset
    image through the model; the ``DataLoader`` uses its default batch size.
    The running mean-squared reconstruction error is logged to ``test.log``.
    It then reconstructs the first ten dataset samples and hands the
    ``(input, output)`` pairs to ``visualize_input_output(..., save=True)``.
    """
    import torch  # for torch.no_grad(); local so this block stands alone

    ###################
    #    Load Data    #
    ###################

    dataset = PokemonDataset()
    data_loader = DataLoader(dataset)

    ###################
    #  Set Up Models  #
    ###################

    auto_encoder = models.AutoEncoder()
    auto_encoder.load_states()

    ###################
    #     Testing     #
    ###################

    auto_encoder.eval()
    log = Logger('test.log', TestingLogTemplate())

    # Inference only: no_grad stops autograd from recording a graph for every
    # batch, which would cost memory and time without ever being used.
    with torch.no_grad():
        total_loss = 0
        for step, (x, _) in enumerate(data_loader):
            # Evaluate reconstruction error (no parameter updates here)
            y = auto_encoder(x)
            loss = F.mse_loss(y, x)

            # Display status
            total_loss += loss.item()
            ctx = {
                'step': step + 1,
                'loss': total_loss / (step + 1),
                'data_len': len(data_loader),
            }
            log.write(ctx, overwrite=True)
    log.close()

    ###################
    #     Visuals     #
    ###################

    # Get inputs and outputs for the first ten pokemon
    samples = []
    with torch.no_grad():
        for x, t in dataset[:10]:
            # unsqueeze adds the batch dimension the model expects
            y = auto_encoder(x.unsqueeze(0))
            samples.append((x.squeeze(), y.detach().squeeze()))

    # Display inputs and outputs
    visualize_input_output(samples, save=True)
Example #5
0
def peek_at_pokemons(pkm_idxs=(170, 256, 53, 86, 115), cls="Water"):
    """Plot selected Pokemon above bar charts of their prediction scores.

    Args:
        pkm_idxs: Indices into both the dataset and the predictions CSV, one
            column of the figure each. Defaults to the hand-picked "Water"
            selection; a "Fire" selection used previously was
            ``[387, 125, 76, 145]``.
        cls: Tag used only in the output file name
            ``figs/<cls>_pokemons_preview.png``.

    The top row shows each image titled with its labels. The bottom row shows
    a horizontal bar chart of the matching row of the predictions CSV at
    ``PREDS_PATH``, covering all columns except the first.
    """
    pkm_idxs = list(pkm_idxs)
    dataset = PokemonDataset()
    # Read the predictions once rather than once per Pokemon.
    preds = pd.read_csv(PREDS_PATH)

    # squeeze=False keeps ``axs`` 2-D even when a single index is requested.
    fig, axs = plt.subplots(2, len(pkm_idxs), figsize=(14, 6), squeeze=False)

    for col, pkm_idx in enumerate(pkm_idxs):
        pkm_img, pkm_labs = dataset[pkm_idx]
        # Drop empty/missing labels (None, "" or a float NaN).
        pkm_labs = ', '.join(
            [l for l in pkm_labs if l and not isinstance(l, float)])
        pred = preds.iloc[pkm_idx]
        # The first CSV column is skipped -- presumably the saved row index;
        # confirm against the code that writes PREDS_PATH.
        pred_clses = pred.keys()[1:]
        pred_vals = pred.to_numpy()[1:]

        axs[0, col].imshow(pkm_img)
        axs[0, col].set_title(pkm_labs)
        axs[0, col].axis('off')

        axs[1, col].barh(pred_clses, pred_vals, align='center')
        axs[1, col].set_yticks(np.arange(len(pred_clses)))
        axs[1, col].set_yticklabels(pred_clses)

    fig.tight_layout()
    fig.savefig(os.path.join("figs", f"{cls}_pokemons_preview.png"))
    # Release the figure once it has been written.
    plt.close(fig)
Example #6
0
import torch
import clip
from data import PokemonDataset
from tqdm import tqdm
import numpy as np
import pandas as pd

from collections import defaultdict

# Experimental CLIP (ViT-B/32) fine-tuning on the Pokemon type labels.
# NOTE(review): in this snippet the loop only clears gradients and collects
# label ids / image tensors; no forward pass, backward or optimizer.step()
# is visible. TODO confirm against the full script.

# NOTE(review): not read anywhere in the visible snippet.
save_preds = True
# NOTE(review): both branches evaluate to "cpu", so this always runs on CPU
# even when CUDA is available. Confirm whether that is intentional.
device = "cpu" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0005, momentum=0.5)

dataset = PokemonDataset()
classes = list(dataset.get_classes())
# Tokenised class names, presumably used as CLIP text prompts later.
text = clip.tokenize(classes).to(device)

# Presumably one example per type (per the method name); confirm.
examples = dataset.fetch_per_type_examples()
for _ in tqdm(range(5)):  # 5 iterations over the example set
    optimizer.zero_grad()

    labels = []
    images = []
    for ex in examples:
        image, label = ex
        # Integer class id; list.index raises ValueError if `label` is not
        # one of `classes`.
        labels.append(classes.index(label))
        # preprocess -> tensor; unsqueeze(0) adds a leading batch dimension.
        image_tensor = preprocess(image).unsqueeze(0).to(device)
        images.append(image_tensor)
Example #7
0
from data import PokemonDataset

# Smoke test: iterate over every sample of PokemonDataset and print it.
dataset = PokemonDataset()
# NOTE(review): `classes` is not used below in this snippet.
classes = dataset.get_classes()
# Each item is unpacked as an (image, label) pair.
for image, label in dataset:
    print(image, label)
Example #8
0
def train():
    """Train the GAN generator/discriminator pair and save the generator.

    Every epoch makes two full passes over the mirror-augmented data: first a
    discriminator ("D") pass, then a generator ("G") pass. The running mean
    loss of the current pass is logged to ``train.log``, tagged with the pass
    type. The generator weights are written to ``states/generator.pt`` at the
    end.
    """
    ###################
    #    Load Data    #
    ###################

    dataset = PokemonDataset(add_mirrored=True)
    data_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=SHUFFLE)

    ###################
    #  Set Up Models  #
    ###################

    generator = models.Generator()
    descriminator = models.Descriminator()

    g_opt = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
    d_opt = optim.SGD(descriminator.parameters(), lr=LEARNING_RATE)
    criterion = torch.nn.BCELoss()

    ###################
    #    Training     #
    ###################

    generator.train()
    descriminator.train()
    template_str = '[{train_type}]  Epoch: {epoch:{epochs_str_len}d}/{epochs}  Item: {step}/{data_len}  Loss: {loss:.4f}'
    log = Logger('train.log', LogTemplate(template_str))

    for epoch in range(EPOCHS):
        ctx = {}
        # First pass trains the discriminator, second pass the generator.
        for d_pass in (True, False):
            # Reset per pass: the logged loss is divided by this pass's step
            # count, so carrying the D-pass total into the G pass would
            # inflate the generator's reported loss.
            pass_loss = 0
            train_type = 'D' if d_pass else 'G'
            for step, (real_img, _) in enumerate(data_loader):
                current_batch_size = real_img.shape[0]

                # Generate image(s) from (batch, 1, 8, 8) Gaussian noise
                z = torch.randn(current_batch_size, 1, 8, 8)
                fake_img = generator(z)

                # Discriminator target values
                real_t = torch.ones((current_batch_size, 1))
                fake_t = torch.zeros((current_batch_size, 1))

                if d_pass:
                    # Discriminator pass. ``detach`` keeps the D loss from
                    # back-propagating into the generator; those gradients
                    # were unused anyway, because g_opt.zero_grad() runs
                    # before every generator update.
                    real_y = descriminator(real_img)
                    fake_y = descriminator(fake_img.detach())
                    d_opt.zero_grad()
                    real_loss = criterion(real_y, real_t)
                    fake_loss = criterion(fake_y, fake_t)
                    loss = (real_loss + fake_loss) / 2
                    loss.backward()
                    d_opt.step()
                else:
                    # Generator pass: push the discriminator to label fakes real.
                    y = descriminator(fake_img)
                    g_opt.zero_grad()
                    loss = criterion(y, real_t)
                    loss.backward()
                    g_opt.step()
                pass_loss += loss.item()

                # Display status
                ctx = {
                    'train_type': train_type,
                    'epoch': epoch + 1,
                    'epochs': EPOCHS,
                    'epochs_str_len': len(str(EPOCHS)),
                    'step': step + 1,
                    'loss': pass_loss / (step + 1),
                    'data_len': len(data_loader),
                }
                log.write(ctx, overwrite=True)
            # Keep the final status of each pass as a persistent log line.
            log.write(ctx)
    log.close()

    ###################
    #   Save Model    #
    ###################

    torch.save(generator.state_dict(), 'states/generator.pt')
Example #9
0
import torch
import clip
from tqdm import tqdm
from data import PokemonDataset
from collections import defaultdict

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

print(device)
dataset = PokemonDataset()
classes = list(dataset.get_classes())

proper_one = defaultdict(int)
proper_two = defaultdict(int)
all_one = defaultdict(int)
all_two = defaultdict(int)
for image, real_labels in tqdm(dataset):
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize(classes).to(device)

    with torch.no_grad():

        logits_per_image, logits_per_text = model(image_tensor, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()
    if real_labels[1] is None:
        label = real_labels[0]
        all_one[label] += 1
        if real_labels[0] == classes[probs.argmax()]:
            proper_one[label] += 1
    else:
Example #10
0
def clean_pokedex():
    """Write the type labels of every dataset item to ``POKEDEX_PATH`` as CSV.

    The CSV has columns ``Type1`` and ``Type2`` (plus the DataFrame index),
    with one row per item yielded by ``PokemonDataset``; the image part of
    each item is ignored.
    """
    rows = []
    for _, (first_type, second_type) in PokemonDataset():
        rows.append([first_type, second_type])
    pokedex = pd.DataFrame(rows, columns=["Type1", "Type2"])
    pokedex.to_csv(POKEDEX_PATH)