コード例 #1
0
def main():
    """Generate NUM_SAMPLES images for a single recipe with a trained generator.

    Command-line arguments (read from ``sys.argv``):
        MODEL_PATH  -- checkpoint file containing a ``'G_state_dict'`` entry.
        DATA_PATH   -- dataset path passed to ``GANstronomyDataset``.
        SPLIT_INDEX -- integer split selected via ``data.set_split_index``.
        OUT_PATH    -- output directory (created if it does not exist).
        RECIPE_ID   -- recipe whose embedding conditions the generator.
        NUM_SAMPLES -- number of random latent vectors / images to generate.

    Exits with status 1 (message on stderr) on a usage error or when
    RECIPE_ID is not present in the selected split.
    """
    if len(sys.argv) < 7:
        # Report usage errors on stderr with a non-zero status. The builtin
        # ``exit()`` is a site-module REPL helper and exited with status 0.
        print(
            'Usage: python3 sample.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_SAMPLES]',
            file=sys.stderr)
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_samples = int(sys.argv[6])

    util.create_dir(out_path)
    # Load tensors straight onto the target device, so a checkpoint saved on
    # a GPU can still be used on a CPU-only machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data,
                             batch_size=1,
                             shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Inference only: no autograd graph is needed for the lookup or generation.
    with torch.no_grad():
        # Scan the split sequentially (batch_size=1) for the requested recipe
        # and keep its embedding.
        embs = None
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            _, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

        # An explicit check instead of ``assert``, which is stripped under -O.
        if embs is None:
            sys.exit('Recipe %s not found in split %d' % (recipe_id, split_index))

        # One latent vector per sample, all conditioned on the same embedding.
        z = torch.randn(num_samples, opts.LATENT_SIZE).to(opts.DEVICE)
        imgs_gen = G(z, embs.expand(num_samples, opts.EMBED_SIZE))

    for i in range(num_samples):
        save_results(out_path, recipe_id, imgs_gen[i], i)
コード例 #2
0
ファイル: interp.py プロジェクト: micklexqg/gan-stronomy
def main():
    """Interpolate between two random latent vectors for a single recipe.

    Command-line arguments (read from ``sys.argv``):
        MODEL_PATH  -- checkpoint file containing a ``'G_state_dict'`` entry.
        DATA_PATH   -- dataset path passed to ``GANstronomyDataset``.
        SPLIT_INDEX -- integer split selected via ``data.set_split_index``.
        OUT_PATH    -- output directory (created if it does not exist).
        RECIPE_ID   -- recipe whose embedding conditions the generator.
        NUM_DIV     -- number of interpolation intervals. ``NUM_DIV + 1``
                       evenly spaced weights in [0, 1] are rendered.

    Exits with status 1 (message on stderr) on a usage error or when
    RECIPE_ID is not present in the selected split.
    """
    if len(sys.argv) < 7:
        # Report usage errors on stderr with a non-zero status. The builtin
        # ``exit()`` is a site-module REPL helper and exited with status 0.
        print('Usage: python3 interp.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH] [RECIPE_ID] [NUM_DIV]',
              file=sys.stderr)
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])
    recipe_id = sys.argv[5]
    num_div = int(sys.argv[6])

    util.create_dir(out_path)
    # Load tensors straight onto the target device, so a checkpoint saved on
    # a GPU can still be used on a CPU-only machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data, batch_size=1, shuffle=False, sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Inference only: no autograd graph is needed for the lookup or generation.
    with torch.no_grad():
        # Scan the split sequentially (batch_size=1) for the requested recipe
        # and keep its embedding.
        embs = None
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            _, recipe_embs, imgs = util.get_variables3(recipe_ids, recipe_embs, img_ids, imgs)
            if recipe_ids[0] == recipe_id:
                embs = recipe_embs
                break

        # An explicit check instead of ``assert``, which is stripped under -O.
        if embs is None:
            sys.exit('Recipe %s not found in split %d' % (recipe_id, split_index))

        # The two interpolation endpoints in latent space.
        z1 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
        z2 = torch.randn(1, opts.LATENT_SIZE).to(opts.DEVICE)
        for alpha in np.linspace(0.0, 1.0, num_div + 1):
            # 0-d float32 tensor on the target device. It is the blend weight
            # and also the value handed to save_results, as before.
            a = torch.tensor(alpha, dtype=torch.float, device=opts.DEVICE)
            z = (1.0 - a) * z1 + a * z2
            img_gen = G(z, embs)[0]
            save_results(out_path, recipe_id, img_gen, a)
コード例 #3
0
def main():
    """Generate an image for every example in a split and save it with the real one.

    Command-line arguments (read from ``sys.argv``):
        MODEL_PATH  -- checkpoint file containing a ``'G_state_dict'`` entry.
        DATA_PATH   -- dataset path passed to ``GANstronomyDataset``.
        SPLIT_INDEX -- integer split selected via ``data.set_split_index``.
        OUT_PATH    -- output directory (created if it does not exist).

    For each example, ``save_results`` receives the ingredient list from
    ``util.load_ingredients()``, the real image, the generated image, the
    image id, the recipe id and OUT_PATH.

    Exits with status 1 (usage on stderr) when too few arguments are given.
    """
    if len(sys.argv) < 5:
        # Report usage errors on stderr with a non-zero status. The builtin
        # ``exit()`` is a site-module REPL helper and exited with status 0.
        print(
            'Usage: python3 test.py [MODEL_PATH] [DATA_PATH] [SPLIT_INDEX] [OUT_PATH]',
            file=sys.stderr)
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])
    split_index = int(sys.argv[3])
    out_path = os.path.abspath(sys.argv[4])

    util.create_dir(out_path)
    # Load tensors straight onto the target device, so a checkpoint saved on
    # a GPU can still be used on a CPU-only machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)
    data.set_split_index(split_index)
    data_loader = DataLoader(data,
                             batch_size=opts.BATCH_SIZE,
                             shuffle=False,
                             sampler=SequentialSampler(data))

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    all_ingrs = util.load_ingredients()

    # Inference only: no autograd graph is needed anywhere in the loop.
    with torch.no_grad():
        for data_batch in data_loader:
            recipe_ids, recipe_embs, img_ids, imgs, _, _, _ = data_batch
            batch_size, recipe_embs, imgs = util.get_variables3(
                recipe_ids, recipe_embs, img_ids, imgs)
            # A fresh latent vector per example in the batch.
            z = torch.randn(batch_size, opts.LATENT_SIZE).to(opts.DEVICE)
            imgs_gen = G(z, recipe_embs)
            imgs, imgs_gen = imgs.detach(), imgs_gen.detach()
            for iexample in range(batch_size):
                save_results(all_ingrs, imgs[iexample], imgs_gen[iexample],
                             img_ids[iexample], recipe_ids[iexample], out_path)
コード例 #4
0
def main():
    """Print FID scores comparing real test images with validation images.

    Command-line arguments (read from ``sys.argv``):
        MODEL_PATH -- checkpoint file containing a ``'G_state_dict'`` entry.
        DATA_PATH  -- dataset path passed to ``GANstronomyDataset``.

    Two scores are printed: FID(test_real, val_real) as a real-vs-real
    reference, and FID(test_real, val_fake) using generator outputs.

    Exits with status 1 (usage on stderr) when too few arguments are given.
    """
    if len(sys.argv) < 3:
        # Report usage errors on stderr with a non-zero status. The builtin
        # ``exit()`` is a site-module REPL helper and exited with status 0.
        print('Usage: python3 score.py [MODEL_PATH] [DATA_PATH]', file=sys.stderr)
        sys.exit(1)
    model_path = os.path.abspath(sys.argv[1])
    data_path = os.path.abspath(sys.argv[2])

    # Load tensors straight onto the target device, so a checkpoint saved on
    # a GPU can still be used on a CPU-only machine.
    saved_model = torch.load(model_path, map_location=opts.DEVICE)

    data = GANstronomyDataset(data_path, split=opts.TVT_SPLIT)

    G = Generator(opts.LATENT_SIZE, opts.EMBED_SIZE).to(opts.DEVICE)
    G.load_state_dict(saved_model['G_state_dict'])
    G.eval()

    # Evaluation only: avoid building an autograd graph while generating.
    with torch.no_grad():
        val_imgs, val_imgs_gen = get_val_imgs(G, data)
    test_imgs = get_test_imgs(data)
    print('FID(test_real, val_real): %f' % util.get_fid(test_imgs, val_imgs))
    print('FID(test_real, val_fake): %f' %
          util.get_fid(test_imgs, val_imgs_gen))
コード例 #5
0
def main():
    """Train the conditional GAN (generator G, discriminator D) on split 0.

    All configuration comes from ``opts``: data/output paths, batch size,
    Adam settings, loss weights, per-batch update counts, and the logging /
    image / checkpoint intervals. If ``opts.MODEL_PATH`` is set, models and
    optimizers are restored with ``load_state_dicts``. The (epoch, batch)
    position it returns is used to skip batches before that point.

    Side effects: creates the run, image and model output directories,
    copies ``opts.py``, ``model.py`` and ``train.py`` into ``opts.RUN_PATH``,
    periodically prints losses and writes sample images and checkpoints,
    and finally saves a checkpoint tagged ``'FINAL'``.
    """
    # Load the data
    data = GANstronomyDataset(opts.DATA_PATH, split=opts.TVT_SPLIT)
    # Split index 0 is the training split (see the SAFETY_MODE check below).
    data.set_split_index(0)
    data_loader = torch.utils.data.DataLoader(data,
                                              batch_size=opts.BATCH_SIZE,
                                              shuffle=True)
    num_classes = data.num_classes()

    # Make the output directory
    util.create_dir(opts.RUN_PATH)
    util.create_dir(opts.IMG_OUT_PATH)
    util.create_dir(opts.MODEL_OUT_PATH)

    # Copy opts.py, model.py and train.py to opts.RUN_PATH as a record
    shutil.copy2('opts.py', opts.RUN_PATH)
    shutil.copy2('model.py', opts.RUN_PATH)
    shutil.copy2('train.py', opts.RUN_PATH)

    # Instantiate the models
    # Both networks are conditioned on the recipe class (one-hot, see below).
    G = Generator(opts.EMBED_SIZE, num_classes).to(opts.DEVICE)
    G_optimizer = torch.optim.Adam(G.parameters(),
                                   lr=opts.ADAM_LR,
                                   betas=opts.ADAM_B)

    D = Discriminator(num_classes).to(opts.DEVICE)
    D_optimizer = torch.optim.Adam(D.parameters(),
                                   lr=opts.ADAM_LR,
                                   betas=opts.ADAM_B)

    if opts.MODEL_PATH is None:
        start_iepoch, start_ibatch = 0, 0
    else:
        print('Attempting to resume training using model in %s...' %
              opts.MODEL_PATH)
        # Restores G/D and both optimizers in place. The meaning of the
        # returned (epoch, batch) pair is defined by load_state_dicts.
        start_iepoch, start_ibatch = load_state_dicts(opts.MODEL_PATH, G,
                                                      G_optimizer, D,
                                                      D_optimizer)

    for iepoch in range(opts.NUM_EPOCHS):
        for ibatch, data_batch in enumerate(data_loader):
            # To try to resume training, just continue if iepoch and ibatch are less than their starts
            # NOTE(review): skipped batches are still fetched from data_loader.
            # With shuffle=True they are presumably not the same batches the
            # interrupted run saw unless RNG state is restored elsewhere.
            if iepoch < start_iepoch or (iepoch == start_iepoch
                                         and ibatch < start_ibatch):
                if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                    print('Skipping epoch %d...' % iepoch)
                continue

            recipe_ids, recipe_embs, img_ids, imgs, classes, noisy_real, noisy_fake = data_batch

            # Make sure we're not training on validation or test data!
            if opts.SAFETY_MODE:
                for recipe_id in recipe_ids:
                    assert data.get_recipe_split_index(recipe_id) == 0

            batch_size, recipe_embs, imgs, classes, classes_one_hot = util.get_variables(
                recipe_ids, recipe_embs, img_ids, imgs, classes, num_classes)
            noisy_real, noisy_fake = util.get_variables2(
                noisy_real, noisy_fake)

            # Adversarial ground truths
            # Constant (batch_size, 1) targets: 1.0 = real, 0.0 = fake.
            all_real = Variable(FloatTensor(batch_size, 1).fill_(1.0),
                                requires_grad=False).to(opts.DEVICE)
            all_fake = Variable(FloatTensor(batch_size, 1).fill_(0.0),
                                requires_grad=False).to(opts.DEVICE)

            # Train Generator
            # Objective: make D label generated images as real (BCE against
            # all_real) plus a pixel-wise MSE to the real image, weighted by
            # opts.A_BCE and opts.A_MSE. Only G_optimizer steps here. The
            # gradients this leaves on D are cleared by D_optimizer.zero_grad().
            # NOTE(review): if opts.NUM_UPDATE_G == 0, imgs_gen and
            # G_BCE_loss/G_MSE_loss are never bound and later uses raise NameError.
            for _ in range(opts.NUM_UPDATE_G):
                G_optimizer.zero_grad()
                imgs_gen = G(recipe_embs, classes_one_hot)

                fake_probs = D(imgs_gen, classes_one_hot)
                G_BCE_loss = BCELoss(fake_probs, all_real)
                G_MSE_loss = MSELoss(imgs_gen, imgs)
                G_loss = opts.A_BCE * G_BCE_loss + opts.A_MSE * G_MSE_loss
                G_loss.backward()
                G_optimizer.step()

            # Train Discriminator
            # Average BCE over the last generated batch (detached so no
            # gradient reaches G) and the real batch. With opts.NOISY_LABELS,
            # the dataset-provided noisy targets replace the constant 0/1 ones.
            # NOTE(review): if opts.NUM_UPDATE_D == 0, D_loss is unbound at
            # the print_loss call below.
            for _ in range(opts.NUM_UPDATE_D):
                D_optimizer.zero_grad()
                fake_probs = D(imgs_gen.detach(), classes_one_hot)
                real_probs = D(imgs, classes_one_hot)
                D_loss = (
                    BCELoss(fake_probs,
                            noisy_fake if opts.NOISY_LABELS else all_fake) +
                    BCELoss(real_probs,
                            noisy_real if opts.NOISY_LABELS else all_real)) / 2
                D_loss.backward()
                D_optimizer.step()

            # Periodic reporting runs once per qualifying epoch, right after
            # that epoch's first batch (ibatch == 0) has been trained.
            if iepoch % opts.INTV_PRINT_LOSS == 0 and not ibatch:
                print_loss(G_BCE_loss, G_MSE_loss, D_loss, iepoch)
            if iepoch % opts.INTV_SAVE_IMG == 0 and not ibatch:
                # Save a training image
                get_img_gen(data, 0, G, iepoch, opts.IMG_OUT_PATH)
                # Save a validation image
                get_img_gen(data, 1, G, iepoch, opts.IMG_OUT_PATH)
            if iepoch % opts.INTV_SAVE_MODEL == 0 and not ibatch:
                print('Saving model...')
                save_model(G, G_optimizer, D, D_optimizer, iepoch,
                           opts.MODEL_OUT_PATH)

    save_model(G, G_optimizer, D, D_optimizer, 'FINAL', opts.MODEL_OUT_PATH)
    print('\a')  # Ring the bell to alert the human