def train_vae(epoch, model, train_loader, cond=False):
    """
    Train a VAE or CVAE for one epoch.

    Inputs:
    - epoch: Current epoch number
    - model: VAE model object
    - train_loader: PyTorch DataLoader containing the training data
    - cond: If True, train a conditional VAE (CVAE); otherwise a plain VAE
    """
    model.train()
    train_loss = 0
    num_classes = 10
    # Note: constructing the optimizer here resets Adam's running statistics
    # on every call; creating it once outside the epoch loop is usually preferable.
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for batch_idx, (data, labels) in enumerate(train_loader):
        data = data.to(device='cuda')
        if cond:
            # Condition the model on a one-hot encoding of the class labels.
            one_hot_vec = one_hot(labels, num_classes).to(device='cuda')
            recon_batch, mu, logvar = model(data, one_hot_vec)
        else:
            recon_batch, mu, logvar = model(data)
        optimizer.zero_grad()
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print('Train Epoch: {} \tAverage loss: {:.6f}'.format(
        epoch, train_loss / len(train_loader)))
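This snippet assumes that loss_function and one_hot are defined elsewhere. A minimal sketch of the usual VAE objective (reconstruction term plus KL divergence); the names and input ranges here are assumptions, not the original project's code:

# Imports and a hypothetical loss_function matching the call above.
import torch
import torch.nn.functional as F
from torch import optim
from torch.nn.functional import one_hot

def loss_function(recon_x, x, mu, logvar):
    # Reconstruction term: binary cross-entropy summed over the batch
    # (appropriate for inputs normalized to [0, 1], e.g. MNIST).
    bce = F.binary_cross_entropy(recon_x, x.view(recon_x.shape), reduction='sum')
    # KL divergence between q(z|x) = N(mu, diag(exp(logvar))) and the prior N(0, I):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

For the conditional branch, torch.nn.functional.one_hot(labels, num_classes) provides the encoding; note it returns a LongTensor, so a .float() cast may be needed before feeding it to the model.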
Example #2
def validate():
    # Relies on module-level model, valid_loader, device, and loss_function.
    model.eval()
    valid_loss = 0
    with torch.no_grad():  # no gradients needed during validation
        for batch_idx, data in enumerate(valid_loader):
            x = data.to(device)
            recon_x, mu, logvar = model(x)
            loss = loss_function(x, recon_x, mu, logvar)
            valid_loss += loss.item()
    return valid_loss
Example #3
def train():
    # Relies on module-level model, optimizer, train_loader, device, and loss_function.
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        x = data.to(device)
        optimizer.zero_grad()
        recon_x, mu, logvar = model(x)
        loss = loss_function(x, recon_x, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    return train_loss
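A typical driver for the train()/validate() pair above might look like the following sketch; the epoch count and the per-sample normalization are assumptions:

# Hypothetical driver loop for Examples #2 and #3.
for epoch in range(1, 51):
    train_loss = train()
    valid_loss = validate()
    print(f"Epoch {epoch}: train {train_loss / len(train_loader.dataset):.4f}, "
          f"valid {valid_loss / len(valid_loader.dataset):.4f}")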
Example #4
def train():
  epochs = 300
  batch_size = 128
  vae = TorchVAE()
  vae.cuda()
  vae.train()
  vae.weight_init(mean=0, std=0.02)
  try:
    vae.load_state_dict(torch.load(vae_data))
    print("loaded state")
  except (FileNotFoundError, RuntimeError):
    # No checkpoint yet, or an incompatible one; start from scratch.
    pass

  lr = 0.0001
  vae_optimizer = torch.optim.Adam(vae.parameters(), lr=lr, betas=(0.5, 0.999), weight_decay=1e-5)
  for epoch in range(epochs):
    data_train = load_data()
    print("Train set size:", len(data_train))
    random.shuffle(data_train)
    # Split into full batches; the "+ 1" keeps the final full batch, and any
    # partial remainder is dropped.
    batches = []
    for i in range(0, len(data_train) - batch_size + 1, batch_size):
      batches.append(np.asarray(data_train[i:i + batch_size]))
    for i, x in enumerate(batches):
      batch = process_batch(x)
      vae.train()  # restore train mode after the periodic eval below
      vae.zero_grad()
      rec, mu, logvar = vae(batch)

      loss_re, loss_kl = loss_function(rec, batch, mu, logvar)
      (loss_re + loss_kl).backward()
      vae_optimizer.step()

      if (i + 1) % 32 == 0:
        # Every 32 batches, visualize reconstructions of the first 8 inputs.
        with torch.no_grad():
          vae.eval()
          x_rec, _, _ = vae(batch[:8])
          resultsample = torch.cat([batch[:8], x_rec]) * 0.5 + 0.5  # undo [-1, 1] normalization
          resultsample = resultsample.cpu()
          ndarr = make_grid(resultsample.view(-1, 3, 128, 128)).mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
          cv2.imshow("rec", cv2.cvtColor(ndarr, cv2.COLOR_RGB2BGR))
          cv2.waitKey(1)
      del batch
      torch.cuda.empty_cache()  # aggressive, but keeps GPU memory bounded at a speed cost

    del batches
    del data_train
    torch.save(vae.state_dict(), vae_data)
  print("Training finish!... save training results")
Example #5
def train_epoch(epoch, config, model, device, train_loader, optimizer, writer):
    model.train()

    for batch_idx, (im_path, sample) in enumerate(train_loader):
        props = sample["props"].to(device)
        in_pos = sample["in_pos"].to(device)
        out_pos = sample["out_pos"].to(device)
        abs_prob = sample["abs"].to(device)
        
        im = image_generate(im_path, config.im_size)
        
        optimizer.zero_grad()
        recon_pos, recon_abs, mu, logvar = model(props, im.to(device), in_pos, out_pos, is_training=True)
        loss_total, losses = loss_function(recon_pos, out_pos, recon_abs, abs_prob, mu, logvar, config)

        loss_total.backward()
        optimizer.step()

        if batch_idx % 750 == 0:
            day_time = datetime.datetime.now()
            n_data = batch_idx * config.loader_args["batch_size"]
            print(f"{day_time} -- Log: data {n_data} / 9600000")

        # Logging with TensorboardX
        global_step = (epoch - 1) * len(train_loader) + batch_idx
        writer.add_scalar("train/total_loss", loss_total.item(), global_step)
        writer.add_scalars("train/loss_weighted",
                           {
                               "latent": losses["latent"] * config.loss_weight_latent,
                               "position": losses["pos"] * config.loss_weight_pos,
                               "absorption": losses["abs"] * config.loss_weight_abs
                           },
                           global_step)
        writer.add_scalars("train/loss",
                           {
                               "latent": losses["latent"],
                               "position": losses["pos"],
                               "absorption": losses["abs"]
                           },
                           global_step)
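Example #5 (and Example #6 below) assumes a loss_function that returns a weighted total plus a dict of unweighted terms. A plausible sketch under that assumption; the individual loss choices are guesses, not the project's actual code:

# Hypothetical loss_function matching the (loss_total, losses) call signature
# used in Examples #5 and #6.
def loss_function(recon_pos, out_pos, recon_abs, abs_prob, mu, logvar, config):
    losses = {
        "pos": torch.nn.functional.mse_loss(recon_pos, out_pos),
        "abs": torch.nn.functional.mse_loss(recon_abs, abs_prob),
        "latent": -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()),
    }
    loss_total = (losses["pos"] * config.loss_weight_pos
                  + losses["abs"] * config.loss_weight_abs
                  + losses["latent"] * config.loss_weight_latent)
    return loss_total, losses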
Example #6
def test(epoch, config, model, device, test_loader, writer):
    day_time = datetime.datetime.now()
    print(f"{day_time} -- Test Start")
    model.eval()
    test_loss_total = 0
    test_loss_latent = 0
    test_loss_pos = 0
    test_loss_abs = 0

    im_show = True  # print one batch of qualitative samples per test run

    cnt_test = 0

    with torch.no_grad():
        for im_path, sample in test_loader:
            props = sample["props"].to(device)
            in_pos = sample["in_pos"].to(device)
            out_pos = sample["out_pos"].to(device)
            abs_prob = sample["abs"].to(device)

            im = image_generate(im_path, config.im_size)

            recon_pos, recon_abs, mu, logvar = model(props, im.to(device), in_pos, out_pos)
            loss_total, losses = loss_function(recon_pos, out_pos, recon_abs, abs_prob, mu, logvar, config)

            test_loss_total += loss_total.item()
            test_loss_latent += losses["latent"]
            test_loss_pos += losses["pos"]
            test_loss_abs += losses["abs"]

            cnt_test += 1

            if im_show:
                print("recon pos diff: " + str(recon_pos[0:5, :] - in_pos[0:5, :]))
                print("ref pos diff: " + str(out_pos[0:5, :] - in_pos[0:5, :]))
                print("recon_abs: " + str(recon_abs[0:5]))
                print("ref_abs: " + str(abs_prob[0:5]))
                im_show = False

    # Average the accumulated losses over the number of test batches.
    test_loss_total /= cnt_test
    test_loss_latent /= cnt_test
    test_loss_pos /= cnt_test
    test_loss_abs /= cnt_test

    writer.add_scalar("test/total_loss", test_loss_total, epoch)
    writer.add_scalars("test/loss_weighted",
                       {
                           "latent": test_loss_latent * config.loss_weight_latent,
                           "position": test_loss_pos * config.loss_weight_pos,
                           "absorption": test_loss_abs * config.loss_weight_abs
                       },
                       epoch)
    writer.add_scalars("test/loss",
                       {
                           "latent": (test_loss_latent),
                           "position": (test_loss_pos),
                           "absorption": (test_loss_abs)
                       },
                       epoch)

    day_time = datetime.datetime.now()
    print(f"{day_time} -- Test End")
Example #7
def main6():
    # VAE smoke test: one training step on a single hand-built document
    doc = Document(content=[[
        'to', 'the', 'editor', 're', 'for', 'women', 'worried', 'about',
        'fertility', 'egg', 'bank', 'is', 'a', 'new', 'option', 'sept', '00',
        'imagine', 'my', 'joy', 'in', 'reading', 'the', 'morning',
        'newspapers', 'on', 'the', 'day', 'of', 'my', '00th', 'birthday',
        'and', 'finding', 'not', 'one', 'but', 'two', 'articles', 'on', 'how',
        'women', 's', 'fertility', 'drops', 'off', 'precipitously', 'after',
        'age', '00'
    ], [
        'one', 'in', 'the', 'times', 'and', 'one', 'in', 'another', 'newspaper'
    ], ['i', 'sense', 'a', 'conspiracy', 'here'],
                            [
                                'have', 'you', 'been', 'talking', 'to', 'my',
                                'mother', 'in', 'law'
                            ], ['laura', 'heymann', 'washington']],
                   summary=[[
                       'laura', 'heymann', 'letter', 'on', 'sept', '00',
                       'article', 'about', 'using', 'egg', 'bank', 'to',
                       'prolong', 'fertility', 'expresses', 'ironic', 'humor',
                       'about', 'her', 'age', 'and', 'chances', 'of',
                       'becoming', 'pregnant'
                   ]],
                   label=[0.01] * 100,
                   label_idx=[0.01] * 100)
    torch.manual_seed(233)
    torch.cuda.set_device(0)
    args = get_args()
    if args.data == "nyt":
        vocab_file = "/home/ml/lyu40/PycharmProjects/data/nyt/lda_domains/preprocessed/vocab_100d.p"
        with open(vocab_file, "rb") as f:
            vocab = pickle.load(f, encoding='latin1')
    else:
        vocab_file = '/home/ml/ydong26/data/CNNDM/CNN_DM_pickle_data/vocab_100d.p'
        with open(vocab_file, "rb") as f:
            vocab = pickle.load(f, encoding='latin1')
    config = Config(
        vocab_size=vocab.embedding.shape[0],
        embedding_dim=vocab.embedding.shape[1],
        category_size=args.category_size,
        category_dim=50,
        word_input_size=100,
        sent_input_size=2 * args.hidden,
        word_GRU_hidden_units=args.hidden,
        sent_GRU_hidden_units=args.hidden,
        pretrained_embedding=vocab.embedding,
        word2id=vocab.w2i,
        id2word=vocab.i2w,
    )
    model = VAE(config)

    if torch.cuda.is_available():
        model.cuda()
    train_loss = 0
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    x = prepare_data(doc, vocab.w2i)  # token ids, e.g. x = [[1, 2, 1], [1, 1]]
    sents = torch.from_numpy(x).cuda()  # Variable is deprecated; tensors suffice
    optimizer.zero_grad()
    loss = 0
    for sent in sents:
        recon_batch, mu, logvar = model(sent.float())
        loss += loss_function(recon_batch, sent, mu, logvar)
    loss.backward()
    train_loss += loss.item()
    optimizer.step()