def main():
    """Generate several images for a single recipe from a trained generator.

    Command line:
        sample.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH RECIPE_ID NUM_SAMPLES

    Loads ``G_state_dict`` from the checkpoint at MODEL_PATH and scans dataset
    split SPLIT_INDEX sequentially (one example per batch). It stops at the
    first example whose recipe id equals RECIPE_ID. It then draws NUM_SAMPLES
    latent vectors and passes each generated image to ``save_results`` along
    with its sample index.

    Exits with a non-zero status on a usage error, when NUM_SAMPLES is not
    positive, or when RECIPE_ID does not occur in the chosen split.
    """
    if len(sys.argv) < 7:
        print(
            'Usage: python3 sample.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_SAMPLES]'
        )
        # `exit()` is the interactive-shell helper and reports success (0).
        # A usage error must report failure to the calling shell.
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_samples = int(sys.argv[6])
    if num_samples < 1:
        sys.exit('NUM_SAMPLES must be a positive integer, got %d' % num_samples)
    util.create_dir(out_path)

    # map_location lets a checkpoint saved on another device (e.g. CUDA)
    # be loaded onto whatever opts.DEVICE is configured here.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)
    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Inference only: neither the search nor the generation needs autograd.
    with torch.no_grad():
        embs = None
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            # Compare ids before converting, so only the matching batch is
            # passed through util.get_variables3.
            if recipe_ids[0] == recipe_id:
                _, embs, _ = util.get_variables3(
                    recipe_ids, recipe_embs, img_ids, imgs)
                break
        # An explicit check (not `assert`) so it still runs under `python -O`.
        if embs is None:
            sys.exit('Recipe %s not found in split %d' % (recipe_id, split_index))

        z = torch.randn(num_samples, opts.LATENT_SIZE).to(opts.DEVICE)
        # The single recipe embedding is broadcast across all latent samples.
        imgs_gen = G(z, embs.expand(num_samples, opts.EMBED_SIZE))

    for i, img_gen in enumerate(imgs_gen):
        save_results(out_path, recipe_id, img_gen, i)
def main():
    """Interpolate between two random latent vectors for a single recipe.

    Command line:
        interp.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH RECIPE_ID NUM_DIV

    Loads ``G_state_dict`` from MODEL_PATH and finds the first example in
    split SPLIT_INDEX whose recipe id equals RECIPE_ID. It draws two latent
    vectors z1 and z2. For NUM_DIV + 1 evenly spaced blend factors ``a`` in
    [0, 1], it generates an image from ``(1 - a) * z1 + a * z2`` and passes it
    to ``save_results`` together with ``a``.

    Exits with a non-zero status on a usage error, a negative NUM_DIV, or
    when RECIPE_ID does not occur in the chosen split.
    """
    if len(sys.argv) < 7:
        print('Usage: python3 interp.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_DIV]')
        # `exit()` is the interactive-shell helper and reports success (0).
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_div = int(sys.argv[6])
    # num_div == -1 would otherwise silently produce zero images.
    if num_div < 0:
        sys.exit('NUM_DIV must be a non-negative integer, got %d' % num_div)
    util.create_dir(out_path)

    # map_location lets a checkpoint saved on another device be loaded here.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)
    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Inference only: neither the search nor the generation needs autograd.
    with torch.no_grad():
        embs = None
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            # Compare ids first so only the matching batch is converted.
            if recipe_ids[0] == recipe_id:
                _, embs, _ = util.get_variables3(
                    recipe_ids, recipe_embs, img_ids, imgs)
                break
        # An explicit check (not `assert`) so it still runs under `python -O`.
        if embs is None:
            sys.exit('Recipe %s not found in split %d' % (recipe_id, split_index))

        z1 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
        z2 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
        for alpha in np.linspace(0.0, 1.0, num_div + 1):
            # A float32 scalar tensor on opts.DEVICE, the same value/type the
            # old Variable(...type(FloatTensor)).to(DEVICE) chain produced.
            a = torch.tensor(alpha, dtype=torch.float, device=opts.DEVICE)
            z = (1.0 - a) * z1 + a * z2
            img_gen = G(z, embs)[0]
            # NOTE(review): save_results receives the blend factor as a 0-dim
            # tensor, as before. If it only formats it into a filename,
            # float(alpha) may be what was intended; confirm against save_results.
            save_results(out_path, recipe_id, img_gen, a)
def main():
    """Generate one image per example of a dataset split and save it with its source.

    Command line:
        test.py MODEL_PATH DATA_PATH SPLIT_INDEX OUT_PATH

    Loads ``G_state_dict`` from MODEL_PATH and iterates split SPLIT_INDEX
    sequentially in batches of ``opts.BATCH_SIZE``. For each example it draws
    a fresh latent vector and generates an image from the recipe embedding.
    It then calls ``save_results`` with the ingredient table, the real image,
    the generated image, the image id, the recipe id and OUT_PATH.

    Exits with a non-zero status on a usage error.
    """
    if len(sys.argv) < 5:
        print(
            'Usage: python3 test.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH]'
        )
        # `exit()` is the interactive-shell helper and reports success (0).
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    util.create_dir(out_path)

    # map_location lets a checkpoint saved on another device be loaded here.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)
    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=opts.BATCH_SIZE, shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()
    all_ingrs = util.load_ingredients()

    # Inference only; one no_grad scope covers the whole pass.
    with torch.no_grad():
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            # Detach in case util.get_variables3 hands back tensors that
            # require grad (no_grad does not clear that flag on inputs).
            imgs, imgs_gen = imgs.detach(), imgs_gen.detach()
            for iexample in range(batch_size):
                save_results(all_ingrs, imgs[iexample], imgs_gen[iexample],
                             img_ids[iexample], recipe_ids[iexample], out_path)
def main():
    """Print FID scores for a trained generator.

    Command line:
        score.py MODEL_PATH DATA_PATH

    Loads ``G_state_dict`` from MODEL_PATH. It collects real and generated
    images via ``get_val_imgs`` and real images via ``get_test_imgs``
    (presumably the validation and test splits; confirm in those helpers).
    It prints ``util.get_fid`` of the test images against the real
    validation images, and against the generated ones.

    Exits with a non-zero status on a usage error.
    """
    if len(sys.argv) < 3:
        print('Usage: python3 score.py [MODEL_PATH] [DATA_PATH]')
        # `exit()` is the interactive-shell helper and reports success (0).
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])

    # map_location lets a checkpoint saved on another device be loaded here.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)
    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    val_imgs, val_imgs_gen = get_val_imgs(G, data)
    test_imgs = get_test_imgs(data)
    print('FID(test_real, val_real): %f' % util.get_fid(test_imgs, val_imgs))
    print('FID(test_real, val_fake): %f' % util.get_fid(test_imgs, val_imgs_gen))
def main() -> None:
    """Train the generator/discriminator pair on split 0 of the dataset.

    All configuration is read from ``opts``. Steps:

    - Create the run, image and model output directories.
    - Copy opts.py, model.py and train.py into ``opts.RUN_PATH`` as a record.
    - Optionally resume from ``opts.MODEL_PATH`` via ``load_state_dicts``.

    Every batch then runs ``opts.NUM_UPDATE_G`` generator steps followed by
    ``opts.NUM_UPDATE_D`` discriminator steps. Both networks are conditioned
    on the one-hot ``classes_one_hot``. At the first batch of selected epochs
    it prints losses, saves a training and a validation sample image, and
    writes a checkpoint. A final checkpoint tagged 'FINAL' is written after
    the last epoch.
    """
    # Load the data (training split only: split index 0)
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    data.set_split_index(0)
    data_loader = torch.utils.data.DataLoader(data, batch_size=opts.BATCH_SIZE, shuffle=True)
    num_classes = data.num_classes()
    # Make the output directories
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py, model.py and train.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)

    # Instantiate the models
    # NOTE(review): this constructs Generator(EMBED_SIZE, num_classes) and calls
    # G(recipe_embs, classes_one_hot) with no latent noise. The sampling/test/score
    # entry points instead use Generator(opts.LATENT_SIZE, opts.EMBED_SIZE) and
    # G(z, embs). Confirm this train.py matches the model.py currently in use.
    G = Generator(opts.EMBED_SIZE, num_classes).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    D = Discriminator(num_classes).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' % opts.MODEL_PATH)
        # load_state_dicts restores G/D and their optimizers and returns the
        # (epoch, batch) position to resume from.
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G, G_optimizer, D, D_optimizer)

    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts.
            # NOTE(review): skipped batches are still produced by the DataLoader,
            # and with shuffle=True they are not the same batches the original
            # run saw, so resuming is approximate.
            if iepoch < start_iepoch or (iepoch == start_iepoch and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue

            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch

            # Make sure we're not training on validation or test data!
            # NOTE(review): this guard uses `assert`, so it is disabled under `python -O`.
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    assert data.get_recipe_split_index(recipe_id) == 0

            batch_size, recipe_embs, imgs, classes, classes_one_hot = util.get_variables(
                recipe_ids, recipe_embs, img_ids, imgs, classes, num_classes)
            noisy_real, noisy_fake = util.get_variables2(
                noisy_real, noisy_fake)

            # Adversarial ground truths: (batch_size, 1) tensors of 1.0 / 0.0
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(opts.DEVICE)
            all_fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False).to(opts.DEVICE)

            # Train Generator: adversarial BCE (fool D into "real") plus a
            # weighted pixel MSE against the real image.
            for _ in range(opts.NUM_UPDATE_G):
                G_optimizer.zero_grad()
                imgs_gen = G(recipe_embs, classes_one_hot)

                fake_probs = D(imgs_gen, classes_one_hot)
                G_BCE_loss = BCELoss(fake_probs, all_real)
                G_MSE_loss = MSELoss(imgs_gen, imgs)
                G_loss = opts.A_BCE * G_BCE_loss + opts.A_MSE * G_MSE_loss
                G_loss.backward()
                G_optimizer.step()

            # Train Discriminator.
            # imgs_gen comes from the last generator iteration above; it was
            # produced before that iteration's G_optimizer.step().
            # detach() keeps D's loss from back-propagating into G.
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                fake_probs = D(imgs_gen.detach(), classes_one_hot)
                real_probs = D(imgs, classes_one_hot)
                # With opts.NOISY_LABELS the targets are the per-batch
                # noisy_fake / noisy_real values from the dataset.
                D_loss = (
                    BCELoss(fake_probs, noisy_fake if opts.NOISY_LABELS else all_fake) +
                    BCELoss(real_probs, noisy_real if opts.NOISY_LABELS else all_real)) / 2
                D_loss.backward()
                D_optimizer.step()

            # Periodic reporting/checkpointing happens on the first batch of an epoch.
            # NOTE(review): G_BCE_loss / G_MSE_loss / D_loss are unbound here if
            # opts.NUM_UPDATE_G or opts.NUM_UPDATE_D is 0.
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_BCE_loss, G_MSE_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch, opts.MODEL_OUT_PATH)
    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a')  # Ring the bell to alert the human