def gan_loss(G, D, imgs, recipe_embs, noisy_real, noisy_fake):
    """Compute the discriminator and generator BCE losses for one batch.

    Args:
        G: Generator, called as ``G(z, recipe_embs)``.
        D: Discriminator, called as ``D(images, recipe_embs)``.
        imgs: Batch of real images; ``imgs.shape[0]`` is used as the batch size.
        recipe_embs: Conditioning embeddings passed to both ``G`` and ``D``.
        noisy_real: BCE targets for D's predictions on the real images.
        noisy_fake: BCE targets for D's predictions on the generated images.

    Returns:
        Tuple ``(D_loss, G_loss)``. The two losses are built from separate
        discriminator forward passes, so each can be backpropagated on its own.
    """
    batch_size = imgs.shape[0]
    z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
    imgs_gen = G(z, recipe_embs)

    # Discriminator loss: the generated images are detached so that
    # D_loss.backward() never propagates into the generator.
    fake_probs = D(imgs_gen.detach(), recipe_embs)
    real_probs = D(imgs, recipe_embs)
    D_loss = BCELoss(fake_probs, noisy_fake) + BCELoss(real_probs, noisy_real)

    # Generator loss: score the *attached* images. Reusing the detached
    # `fake_probs` here would cut G out of the graph and leave it without
    # any gradient. Backpropagating G_loss also accumulates gradients in D's
    # parameters, so the caller should zero D's gradients before stepping D.
    gen_probs = D(imgs_gen, recipe_embs)
    # Target of ones with the same shape, dtype and device as D's output.
    all_real = torch.ones_like(gen_probs)
    G_loss = BCELoss(gen_probs, all_real)

    return D_loss, G_loss
def get_variables(recipe_ids, recipe_embs, img_ids, imgs, classes, num_classes):
    """Cast a raw batch to typed tensors on ``opts.DEVICE`` and one-hot encode the classes.

    ``recipe_ids`` and ``img_ids`` are accepted but not used by this function.

    Args:
        recipe_ids: Unused.
        recipe_embs: Recipe embeddings; cast to ``FloatTensor``.
        img_ids: Unused.
        imgs: Images; cast to ``FloatTensor``. ``imgs.shape[0]`` is the batch size.
        classes: Class indices; cast to ``LongTensor``.
        num_classes: Width of the one-hot encoding.

    Returns:
        ``(batch_size, recipe_embs, imgs, classes, classes_one_hot)`` where
        ``classes_one_hot`` has shape ``(batch_size, num_classes)``.
    """
    device = opts.DEVICE
    batch_size = imgs.shape[0]

    # Set up Variables
    recipe_embs = Variable(recipe_embs.type(FloatTensor)).to(device)
    imgs = Variable(imgs.type(FloatTensor)).to(device)
    classes = Variable(classes.type(LongTensor)).to(device)

    # Start from an all-zero matrix and write a 1 into each row at the column
    # given by that sample's class index.
    one_hot = FloatTensor(batch_size, num_classes).zero_()
    one_hot.scatter_(1, classes.view(-1, 1), 1)
    classes_one_hot = Variable(one_hot).to(device)

    return batch_size, recipe_embs, imgs, classes, classes_one_hot
def main():
    """Train the generator and discriminator, logging, sampling and checkpointing as it goes.

    All configuration is read from ``opts``. The run directories are created
    and ``opts.py``, ``model.py`` and ``train.py`` are copied into
    ``opts.RUN_PATH`` (paths are relative to the current working directory).
    If ``opts.MODEL_PATH`` is set, model and optimizer state are loaded from
    it and batches before the stored (epoch, batch) position are skipped.
    On the first batch of an epoch, and only on epochs matching the
    corresponding ``opts.INTV_*`` interval, the losses are printed, sample
    images are written and a checkpoint is saved. A checkpoint tagged
    ``'FINAL'`` is saved after the last epoch.

    Raises:
        AssertionError: If ``opts.SAFETY_MODE`` is on and a batch contains a
            recipe whose split index is not 0.
    """
    # Load the data
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    data.set_split_index(0)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=opts.BATCH_SIZE,
                                              shuffle=True)
    num_classes = data.num_classes()

    # Make the output directory
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py, model.py and train.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)

    # Instantiate the models
    G = Generator(opts.EMBED_SIZE, num_classes).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    D = Discriminator(num_classes).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(), lr=opts.ADAM_LR, betas=opts.ADAM_B)

    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' % opts.MODEL_PATH)
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G, G_optimizer, D, D_optimizer)

    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts
            if iepoch < start_iepoch or (iepoch == start_iepoch and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue

            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch

            # Make sure we're not training on validation or test data!
            # Raised explicitly rather than with `assert` so the check still
            # runs when Python is started with -O.
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    if data.get_recipe_split_index(recipe_id) != 0:
                        raise AssertionError(
                            'Recipe %s is not in the training split' % (recipe_id,))

            batch_size, recipe_embs, imgs, classes, classes_one_hot = util.get_variables(
                recipe_ids, recipe_embs, img_ids, imgs, classes, num_classes)
            noisy_real, noisy_fake = util.get_variables2(noisy_real, noisy_fake)

            # Adversarial ground truths
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False).to(opts.DEVICE)
            all_fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False).to(opts.DEVICE)

            # Train Generator
            for _ in range(opts.NUM_UPDATE_G):
                G_optimizer.zero_grad()
                imgs_gen = G(recipe_embs, classes_one_hot)

                fake_probs = D(imgs_gen, classes_one_hot)
                G_BCE_loss = BCELoss(fake_probs, all_real)
                G_MSE_loss = MSELoss(imgs_gen, imgs)
                # Weighted sum of the adversarial term and the reconstruction term.
                G_loss = opts.A_BCE * G_BCE_loss + opts.A_MSE * G_MSE_loss
                G_loss.backward()
                G_optimizer.step()

            # Train Discriminator
            # `imgs_gen` is the batch produced by the last generator update
            # above; it is detached so this step does not backpropagate into G.
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                fake_probs = D(imgs_gen.detach(), classes_one_hot)
                real_probs = D(imgs, classes_one_hot)
                D_loss = (
                    BCELoss(fake_probs, noisy_fake if opts.NOISY_LABELS else all_fake) +
                    BCELoss(real_probs, noisy_real if opts.NOISY_LABELS else all_real)) / 2
                D_loss.backward()
                D_optimizer.step()

            # Periodic reporting happens only on the first batch of an epoch.
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_BCE_loss, G_MSE_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch, opts.MODEL_OUT_PATH)

    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a')  # Ring the bell to alert the human