def get_test_imgs(data):
    """Return the images of split 2 (the test split) as one detached tensor.

    Side effect: ``data`` is left switched to split index 2; the previous
    split index is not restored.

    Args:
        data: Dataset exposing ``set_split_index`` and ``__len__`` whose
            collated batches unpack into exactly seven fields.

    Returns:
        ``imgs.detach()``, where ``imgs`` is the image value returned by
        ``util.get_variables3`` for the batch.
    """
    data.set_split_index(2)

    # A single in-order batch as large as the dataset, so the loop body runs
    # once and ``imgs`` ends up covering the whole split.
    whole_split_loader = DataLoader(
        data,
        batch_size=len(data),
        shuffle=False,
        sampler=SequentialSampler(data),
    )
    for recipe_ids, recipe_embs, img_ids, raw_imgs, _classes, _, _ in whole_split_loader:
        _, _, imgs = util.get_variables3(recipe_ids, recipe_embs, img_ids, raw_imgs)

    return imgs.detach()
def get_img_gen(data, split_index, G, iepoch, out_path):
    """Generate one image for the first example of a split and save it.

    ``data`` is temporarily switched to ``split_index``. The first example's
    recipe embedding, together with freshly drawn standard-normal noise, is
    fed to ``G`` and the first generated image is passed to ``save_img``.
    The split index that was active on entry is restored before returning,
    including when an exception is raised part-way through.

    Args:
        data: Dataset exposing ``split_index`` and ``set_split_index``.
        split_index: Split to draw the example from.
        G: Generator, called as ``G(z, recipe_embs)``.
        iepoch: Epoch number, forwarded to ``save_img``.
        out_path: Output location, forwarded to ``save_img``.

    Raises:
        StopIteration: If the selected split yields no batch.
    """
    old_split_index = data.split_index
    data.set_split_index(split_index)
    try:
        data_loader = torch.utils.data.DataLoader(data, batch_size=1, shuffle=False)
        data_batch = next(iter(data_loader))

        with torch.no_grad():
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            save_img(imgs_gen[0], iepoch, out_path, split_index, recipe_ids[0], img_ids[0])
    finally:
        # The dataset object is shared with the caller (e.g. a training loop
        # iterating another split), so always put its split index back.
        data.set_split_index(old_split_index)
def get_val_imgs(G, data):
    """Return the real images of split 1 and images generated for them.

    Side effect: ``data`` is left switched to split index 1 (the validation
    split); the previous split index is not restored.

    Args:
        G: Generator, called as ``G(z, recipe_embs)``.
        data: Dataset exposing ``set_split_index`` and ``__len__`` whose
            collated batches unpack into exactly seven fields.

    Returns:
        Tuple ``(imgs.detach(), imgs_gen.detach())``: the image value returned
        by ``util.get_variables3`` and the generator output for the same
        batch.
    """
    data.set_split_index(1)  # Set to validation split

    # A single in-order batch as large as the dataset, so the loop body runs
    # once over the whole split.
    whole_split_loader = DataLoader(
        data,
        batch_size=len(data),
        shuffle=False,
        sampler=SequentialSampler(data),
    )
    for recipe_ids, recipe_embs, img_ids, raw_imgs, _classes, _, _ in whole_split_loader:
        with torch.no_grad():
            n_examples, embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, raw_imgs)
            noise = torch.randn(n_examples, opts.LATENT_SIZE)
            imgs_gen = G(noise.to(opts.DEVICE), embs)

    return imgs.detach(), imgs_gen.detach()
def main():
    """Sample several images for one recipe from a trained generator.

    Command line:
        python3 sample.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH RECIPE_ID NUM_SAMPLES

    Loads the generator weights stored under ``'G_state_dict'`` in the
    checkpoint at MODEL_PATH, scans the chosen split for RECIPE_ID, draws
    NUM_SAMPLES noise vectors, and passes each generated image to
    ``save_results`` together with its sample index.

    Exits with a non-zero status when too few arguments are given or when
    RECIPE_ID does not occur in the selected split.
    """
    if len(sys.argv) < 7:
        print(
            'Usage: python3 sample.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_SAMPLES]'
        )
        # sys.exit, not the interactive-only ``exit`` helper; a non-zero
        # status tells calling scripts that nothing was generated.
        sys.exit(1)

    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_samples = int(sys.argv[6])

    util.create_dir(out_path)

    # map_location lets a checkpoint written on another device (for example
    # a GPU) be loaded onto whatever opts.DEVICE is on this machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Walk the split one example at a time until the requested recipe appears.
    embs = None
    for data_batch in data_loader:
        with torch.no_grad():
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            _, recipe_embs, _ = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and a missing recipe is a user-input error, not a programming error.
    if embs is None:
        sys.exit('Recipe ID %r not found in split %d of %s'
                 % (recipe_id, split_index, data_path))

    # Inference only, so skip building an autograd graph.
    with torch.no_grad():
        z = torch.randn(num_samples, opts.LATENT_SIZE).to(opts.DEVICE)
        imgs_gen = G(z, embs.expand(num_samples, opts.EMBED_SIZE))

    for i, img_gen in enumerate(imgs_gen):
        save_results(out_path, recipe_id, img_gen, i)
def main():
    """Interpolate between two random latent vectors for one recipe.

    Command line:
        python3 interp.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH RECIPE_ID NUM_DIV

    Loads the generator weights stored under ``'G_state_dict'`` in the
    checkpoint at MODEL_PATH, scans the chosen split for RECIPE_ID, draws two
    noise vectors ``z1`` and ``z2``, and for NUM_DIV + 1 evenly spaced
    weights ``a`` in [0, 1] generates an image from ``(1 - a) * z1 + a * z2``.
    Each image is passed to ``save_results`` together with its weight (a
    0-dim float tensor on ``opts.DEVICE``).

    Exits with a non-zero status when too few arguments are given or when
    RECIPE_ID does not occur in the selected split.
    """
    if len(sys.argv) < 7:
        print('Usage: python3 interp.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_DIV]')
        # sys.exit, not the interactive-only ``exit`` helper; a non-zero
        # status tells calling scripts that nothing was generated.
        sys.exit(1)

    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_div = int(sys.argv[6])

    util.create_dir(out_path)

    # map_location lets a checkpoint written on another device (for example
    # a GPU) be loaded onto whatever opts.DEVICE is on this machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Walk the split one example at a time until the requested recipe appears.
    embs = None
    for data_batch in data_loader:
        with torch.no_grad():
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            _, recipe_embs, _ = util.get_variables3(recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

    # Explicit check instead of ``assert``: asserts vanish under ``python -O``
    # and a missing recipe is a user-input error, not a programming error.
    if embs is None:
        sys.exit('Recipe ID %r not found in split %d of %s'
                 % (recipe_id, split_index, data_path))

    # Inference only, so skip building an autograd graph.
    with torch.no_grad():
        z1 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
        z2 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)

        for frac in np.linspace(0.0, 1.0, num_div + 1):
            # float32 scalar tensor on the model's device, so the blend below
            # stays in the same dtype and device as z1 and z2.
            a = torch.tensor(float(frac), dtype=torch.float).to(opts.DEVICE)
            z = (1.0 - a) * z1 + a * z2
            img_gen = G(z, embs)[0]
            save_results(out_path, recipe_id, img_gen, a)
def main():
    """Generate an image for every example of a split and save the results.

    Command line:
        python3 test.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH

    Loads the generator weights stored under ``'G_state_dict'`` in the
    checkpoint at MODEL_PATH, iterates the chosen split in order in batches
    of ``opts.BATCH_SIZE``, generates one image per example from fresh
    standard-normal noise, and passes the real image, the generated image,
    both IDs and the ingredient data from ``util.load_ingredients()`` to
    ``save_results``.

    Exits with a non-zero status when too few arguments are given.
    """
    if len(sys.argv) < 5:
        print(
            'Usage: python3 test.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH]'
        )
        # sys.exit, not the interactive-only ``exit`` helper; a non-zero
        # status tells calling scripts that nothing was generated.
        sys.exit(1)

    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])

    util.create_dir(out_path)

    # map_location lets a checkpoint written on another device (for example
    # a GPU) be loaded onto whatever opts.DEVICE is on this machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=opts.BATCH_SIZE, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    all_ingrs = util.load_ingredients()

    for data_batch in data_loader:
        with torch.no_grad():
            recipe_ids, recipe_embs, img_ids, imgs, _classes, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            imgs, imgs_gen = imgs.detach(), imgs_gen.detach()

        for iexample in range(batch_size):
            save_results(all_ingrs, imgs[iexample], imgs_gen[iexample],
                         img_ids[iexample], recipe_ids[iexample], out_path)
def main():
    """Train the generator and discriminator on split 0 of the dataset.

    Reads all configuration from ``opts``. Creates the run, image and model
    output directories, copies ``opts.py``, ``model.py`` and ``train.py``
    (relative paths, so resolved against the current working directory) into
    ``opts.RUN_PATH`` as a record, optionally resumes from
    ``opts.MODEL_PATH``, then runs ``opts.NUM_EPOCHS`` epochs. On the first
    batch of an epoch it may print losses, save sample images and save a
    checkpoint, depending on the ``opts.INTV_*`` intervals. After the last
    epoch a checkpoint tagged ``'FINAL'`` is saved and the terminal bell
    character is printed.
    """
    # Load the data
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    data.set_split_index(0)
    data_loader = torch.utils.data.DataLoader(data, batch_size=opts.BATCH_SIZE, shuffle=True)

    # Make the output directories
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py, model.py and train.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)

    # Instantiate the models and one Adam optimizer for each
    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    D = Discriminator(opts.EMBED_SIZE).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    # Start from scratch, or take the starting epoch/batch from a saved model
    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' % opts.MODEL_PATH)
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G, G_optimizer, D, D_optimizer)

    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts.
            # Skipped batches are still fetched from data_loader; only the training step is bypassed.
            if iepoch < start_iepoch or (iepoch == start_iepoch and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue

            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch
            # presumably noisy_real / noisy_fake are noisy label targets supplied by the dataset;
            # verify against GANstronomyDataset
            noisy_real, noisy_fake = util.get_variables2(noisy_real, noisy_fake)

            # Make sure we're not training on validation or test data!
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    assert data.get_recipe_split_index(recipe_id) == 0

            batch_size, recipe_embs, imgs = util.get_variables3(recipe_ids, recipe_embs, img_ids, imgs)

            # Adversarial ground truths: a (batch_size, 1) tensor of ones used as the generator's target
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(opts.DEVICE)
            # all_fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False).to(opts.DEVICE)
            # NOTE(review): this draw is overwritten under "Train Discriminator" before it is read.
            # It still advances the random state, so removing it would change seeded runs.
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            # D_loss, G_loss = wasserstein_loss(G, D, imgs, recipe_embs)

            # Train Discriminator
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                # detach() keeps the discriminator loss from back-propagating into G
                fake_probs = D(imgs_gen.detach(), recipe_embs)
                real_probs = D(imgs, recipe_embs)
                # BCELoss is defined outside this block; presumably a binary cross-entropy criterion (confirm)
                D_loss = BCELoss(fake_probs, noisy_fake) + BCELoss(real_probs, noisy_real)
                # NOTE(review): the fake pass uses a detached input, so retain_graph=True may be
                # unnecessary; confirm before removing
                D_loss.backward(retain_graph=True)
                D_optimizer.step()

            # Train Generator: score the same generated images (not detached) against the all-ones target
            G_optimizer.zero_grad()
            fake_probs = D(imgs_gen, recipe_embs)
            G_loss = BCELoss(fake_probs, all_real)
            G_loss.backward()
            G_optimizer.step()

            # Periodic reporting happens only on the first batch (ibatch == 0) of qualifying epochs.
            # D_loss is bound only inside the discriminator loop, i.e. when opts.NUM_UPDATE_D >= 1.
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch, opts.MODEL_OUT_PATH)

    # Final checkpoint after the last epoch, tagged 'FINAL' instead of an epoch number
    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a') # Ring the bell to alert the human