示例#1
0
def get_test_imgs(data):
    """Return every image of split 2 as a single detached batch.

    The dataset is switched to split index 2 (the test split, going by this
    function's name) and is left on that split when the function returns.

    Args:
        data: dataset exposing ``set_split_index`` and ``__len__`` and
            yielding 7-field samples.

    Returns:
        The image tensor produced by ``util.get_variables3`` for the first
        (and only) batch, detached from any autograd graph; ``None`` if the
        loader yields no batch.
    """
    data.set_split_index(2)

    # One batch that spans the whole split, in dataset order.
    loader = DataLoader(
        data,
        batch_size=len(data),
        shuffle=False,
        sampler=SequentialSampler(data),
    )

    first_batch = next(iter(loader), None)
    if first_batch is None:
        return None

    ids, embeddings, image_ids, raw_imgs, _classes, _, _ = first_batch
    _n_examples, _embeddings, test_imgs = util.get_variables3(
        ids, embeddings, image_ids, raw_imgs)
    return test_imgs.detach()
示例#2
0
def get_img_gen(data, split_index, G, iepoch, out_path):
    """Generate and save one image for the first sample of a given split.

    The dataset is temporarily switched to ``split_index``; its previous
    split index is restored before returning, even if loading, generation
    or saving raises.

    Args:
        data: dataset exposing ``split_index`` and ``set_split_index``.
        split_index: split to draw the conditioning sample from.
        G: generator called as ``G(z, recipe_embs)``.
        iepoch: epoch label forwarded to ``save_img``.
        out_path: output location forwarded to ``save_img``.

    Raises:
        StopIteration: if the selected split yields no batch.
    """
    old_split_index = data.split_index
    data.set_split_index(split_index)
    try:
        # Unshuffled, batch size 1: only the first sample of the split is used.
        data_loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
        data_batch = next(iter(data_loader))
        with torch.no_grad():
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            save_img(imgs_gen[0], iepoch, out_path, split_index, recipe_ids[0], img_ids[0])
    finally:
        # Always hand the dataset back on the split the caller was using;
        # otherwise a failure here would leave the caller iterating the
        # wrong split.
        data.set_split_index(old_split_index)
示例#3
0
def get_val_imgs(G, data):
    """Return the real validation images and one generated image per sample.

    The dataset is switched to split index 1 (the validation split) and is
    left on that split when the function returns.

    Args:
        G: generator called as ``G(z, recipe_embs)``.
        data: dataset exposing ``set_split_index`` and ``__len__`` and
            yielding 7-field samples.

    Returns:
        A ``(real_images, generated_images)`` pair of detached tensors for
        the first (and only) batch; ``None`` if the loader yields no batch.
    """
    data.set_split_index(1)  # validation split

    # One batch that spans the whole split, in dataset order.
    loader = DataLoader(
        data,
        batch_size=len(data),
        shuffle=False,
        sampler=SequentialSampler(data),
    )

    first_batch = next(iter(loader), None)
    if first_batch is None:
        return None

    with torch.no_grad():
        ids, embeddings, image_ids, raw_imgs, _classes, _, _ = first_batch
        n_examples, embeddings, real_imgs = util.get_variables3(
            ids, embeddings, image_ids, raw_imgs)
        # Fresh latent noise, one vector per validation example.
        noise = torch.randn(n_examples, opts.LATENT_SIZE).to(opts.DEVICE)
        fake_imgs = G(noise, embeddings)
        return real_imgs.detach(), fake_imgs.detach()
示例#4
0
def main():
    """Sample several generated images for a single recipe.

    Command line: ``sample.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH
    RECIPE_ID NUM_SAMPLES``. Loads the generator weights stored under
    ``'G_state_dict'`` in the checkpoint, looks up the embedding of
    ``RECIPE_ID`` in the requested split, draws ``NUM_SAMPLES`` latent
    vectors and passes each generated image to ``save_results``.

    Exits with a non-zero status when too few arguments are given or when
    the recipe id is not present in the split.
    """
    if len(sys.argv) < 7:
        print(
            'Usage: python3 sample.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_SAMPLES]'
        )
        # sys.exit instead of the site-provided exit(); non-zero because this
        # is a usage error.
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_samples = int(sys.argv[6])

    util.create_dir(out_path)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded on whatever device this run is configured for.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data,
                             batch_size=1,
                             shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Scan the split one sample at a time until the requested recipe shows up.
    embs = None
    with torch.no_grad():
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            _batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

    if embs is None:
        # Explicit check rather than assert, which is stripped under -O.
        sys.exit('Recipe id %s was not found in split %d' % (recipe_id, split_index))

    # Inference only: no autograd graph is needed for sampling.
    with torch.no_grad():
        z = torch.randn(num_samples, opts.LATENT_SIZE).to(opts.DEVICE)
        # Repeat the single recipe embedding once per latent sample.
        imgs_gen = G(z, embs.expand(num_samples, opts.EMBED_SIZE))
    for i, img_gen in enumerate(imgs_gen):
        save_results(out_path, recipe_id, img_gen, i)
示例#5
0
def main():
    """Save generated images along a straight line between two latent vectors.

    Command line: ``interp.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH
    RECIPE_ID NUM_DIV``. Loads the generator weights stored under
    ``'G_state_dict'`` in the checkpoint, looks up the embedding of
    ``RECIPE_ID`` in the requested split, draws two random latent vectors
    and generates one image for each of ``NUM_DIV + 1`` evenly spaced
    interpolation weights in ``[0, 1]``, passing each to ``save_results``.

    Exits with a non-zero status when too few arguments are given or when
    the recipe id is not present in the split.
    """
    if len(sys.argv) < 7:
        print('Usage: python3 interp.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_DIV]')
        # sys.exit instead of the site-provided exit(); non-zero because this
        # is a usage error.
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_div = int(sys.argv[6])

    util.create_dir(out_path)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded on whatever device this run is configured for.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False, sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Scan the split one sample at a time until the requested recipe shows up.
    embs = None
    with torch.no_grad():
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            _batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

    if embs is None:
        # Explicit check rather than assert, which is stripped under -O.
        sys.exit('Recipe id %s was not found in split %d' % (recipe_id, split_index))

    # Endpoints of the interpolation in latent space.
    z1 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
    z2 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)

    # Inference only: no autograd graph is needed for generation.
    with torch.no_grad():
        for a in np.linspace(0.0, 1.0, num_div + 1):
            # The weight is kept as a tensor on the target device because it
            # is also handed to save_results in that form.
            a = torch.tensor(a, dtype=torch.float)
            a = Variable(a.type(FloatTensor)).to(opts.DEVICE)
            z = (1.0 - a) * z1 + a * z2
            img_gen = G(z, embs)[0]
            save_results(out_path, recipe_id, img_gen, a)
示例#6
0
def main():
    """Generate one image per sample of a split and save it with the real one.

    Command line: ``test.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH``.
    Loads the generator weights stored under ``'G_state_dict'`` in the
    checkpoint, walks the requested split in dataset order in batches of
    ``opts.BATCH_SIZE``, generates an image for every recipe embedding and
    passes each real/generated pair to ``save_results``.

    Exits with a non-zero status when too few arguments are given.
    """
    if len(sys.argv) < 5:
        print(
            'Usage: python3 test.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH]'
        )
        # sys.exit instead of the site-provided exit(); non-zero because this
        # is a usage error.
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])

    util.create_dir(out_path)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded on whatever device this run is configured for.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data,
                             batch_size=opts.BATCH_SIZE,
                             shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    all_ingrs = util.load_ingredients()

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            imgs, imgs_gen = imgs.detach(), imgs_gen.detach()
            for iexample in range(batch_size):
                save_results(all_ingrs, imgs[iexample], imgs_gen[iexample],
                             img_ids[iexample], recipe_ids[iexample], out_path)
示例#7
0
def main() -> None:
    """Train the generator/discriminator pair on split 0 of the dataset.

    Reads all settings from ``opts``. Creates the run, image and model output
    directories, copies ``opts.py``, ``model.py`` and ``train.py`` (relative
    paths, so they are resolved against the current working directory) into
    ``opts.RUN_PATH``, optionally resumes from ``opts.MODEL_PATH``, then runs
    ``opts.NUM_EPOCHS`` epochs. On the first batch of selected epochs it
    prints losses, saves sample images and saves a checkpoint; a checkpoint
    labelled ``'FINAL'`` is written at the end.
    """
    # Load the data
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    data.set_split_index(0)  # split 0 is the training split (see safety check below)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=opts.BATCH_SIZE,
                                              shuffle=True)

    # Make the output directory
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py, model.py and train.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)
    
    # Instantiate the models; each network gets its own Adam optimizer
    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    D = Discriminator(opts.EMBED_SIZE).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    # Start from scratch unless a checkpoint path is configured
    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' % opts.MODEL_PATH)
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G, G_optimizer, D, D_optimizer)
    
    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts.
            # Skipped batches are still fetched from the data loader before being discarded.
            if iepoch < start_iepoch or (iepoch == start_iepoch and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue
            
            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch
            # noisy_real / noisy_fake are the discriminator targets used below
            # (presumably noised labels supplied by the dataset; verify there).
            noisy_real, noisy_fake = util.get_variables2(noisy_real, noisy_fake)

            # Make sure we're not training on validation or test data!
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    assert data.get_recipe_split_index(recipe_id) == 0

            batch_size, recipe_embs, imgs = util.get_variables3(recipe_ids, recipe_embs, img_ids, imgs)

            # Adversarial ground truths: an all-ones target used for the generator loss
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(opts.DEVICE)
            # all_fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False).to(opts.DEVICE)

            # NOTE(review): this draw is overwritten a few lines below before it is
            # used; it only advances the random number generator.
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            # D_loss, G_loss = wasserstein_loss(G, D, imgs, recipe_embs)
            
            # Train Discriminator
            # The fake images are generated once per batch and reused for every
            # discriminator update and for the generator update.
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                # detach() keeps the discriminator loss from back-propagating into G
                fake_probs = D(imgs_gen.detach(), recipe_embs)
                real_probs = D(imgs, recipe_embs)
                D_loss = BCELoss(fake_probs, noisy_fake) + BCELoss(real_probs, noisy_real)
                # NOTE(review): retain_graph=True looks unnecessary because both
                # forward passes are recomputed every iteration; confirm before removing.
                D_loss.backward(retain_graph=True)
                D_optimizer.step()

            # Train Generator: score the same fake images with the updated D and
            # push them towards the all-ones target
            G_optimizer.zero_grad()
            fake_probs = D(imgs_gen, recipe_embs)
            G_loss = BCELoss(fake_probs, all_real)
            G_loss.backward()
            G_optimizer.step()
            
            # Periodic reporting happens only on the first batch (ibatch == 0) of an epoch.
            # NOTE(review): D_loss is unbound here if opts.NUM_UPDATE_D is 0.
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch, opts.MODEL_OUT_PATH)

    # Final checkpoint, labelled 'FINAL' instead of an epoch number
    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a') # Ring the bell to alert the human