Example #1
def train():
    train_loss = []
    for batch_idx, (x, label) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)
        label = label.to(DEVICE)

        # Get the latent codes for image x
        latents, _ = autoencoder.encode(x)

        # Train PixelCNN with latent codes
        latents = latents.detach()
        logits = model(latents, label)
        logits = logits.permute(0, 2, 3, 1).contiguous()

        loss = criterion(logits.view(-1, K), latents.view(-1))

        opt.zero_grad()
        loss.backward()
        opt.step()

        train_loss.append(to_scalar(loss))

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter: [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                time.time() - start_time))
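
Every example below logs losses through a small helper, to_scalar, whose definition is not shown. A minimal sketch, assuming it only converts a loss tensor (or a list of loss tensors) into plain Python floats for logging:

def to_scalar(arr):
    # Convert a tensor, or a list of tensors, to plain Python numbers
    if isinstance(arr, list):
        return [x.item() for x in arr]
    return arr.item()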
Example #2
def train(epoch):
    train_loss = []
    for batch_idx, (data, _) in enumerate(train_loader):
        start_time = time.time()
        x = Variable(data, requires_grad=False).cuda()

        opt.zero_grad()

        x_tilde, z_e_x, z_q_x = model(x)
        z_q_x.retain_grad()

        # Reconstruction loss (trains the decoder and, via the straight-through
        # gradient copy below, the encoder)
        loss_1 = F.binary_cross_entropy(x_tilde, x)
        # loss_1 = F.l1_loss(x_tilde, x)
        loss_1.backward(retain_graph=True)
        # Discard the codebook gradients left by the reconstruction loss and
        # copy the decoder gradient from z_q_x straight through to z_e_x
        model.embedding.zero_grad()
        z_e_x.backward(z_q_x.grad, retain_graph=True)

        # Vector-quantization loss: move codebook entries towards encoder outputs
        loss_2 = F.mse_loss(z_q_x, z_e_x.detach())
        loss_2.backward(retain_graph=True)

        # Commitment loss: keep encoder outputs close to their codebook entries
        loss_3 = 0.25 * F.mse_loss(z_e_x, z_q_x.detach())
        loss_3.backward()
        opt.step()

        train_loss.append(to_scalar([loss_1, loss_2]))

        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader),
            np.asarray(train_loss).mean(0),
            time.time() - start_time))
Example #3
def train(epoch):
    train_loss = []
    for batch_idx, (data, _) in enumerate(train_loader):
        start_time = time.time()

        if data.size(0) != 64:
            continue
        x_real = Variable(data, requires_grad=False).cuda()

        netD.zero_grad()
        # train with real
        D_real = netD(x_real)
        D_real = D_real.mean()
        D_real.backward(mone)  # 'mone'/'one' are assumed to be scalar tensors -1/+1 defined elsewhere

        # train with fake
        z = Variable(torch.randn(64, 128)).cuda()
        x_fake = Variable(netG(z).data)
        D_fake = netD(x_fake)
        D_fake = D_fake.mean()
        D_fake.backward(one)

        # train with gradient penalty
        gradient_penalty = calc_gradient_penalty(netD, x_real.data,
                                                 x_fake.data)
        gradient_penalty.backward()

        D_cost = D_fake - D_real + gradient_penalty
        Wasserstein_D = D_real - D_fake
        optimizerD.step()

        if (batch_idx + 1) % 6 == 0:
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation

            netG.zero_grad()
            z = Variable(torch.randn(64, 128)).cuda()
            x_fake = netG(z)
            D_fake = netD(x_fake)
            D_fake = D_fake.mean()
            D_fake.backward(mone)
            G_cost = -D_fake
            optimizerG.step()

            for p in netD.parameters():  # reset requires_grad
                p.requires_grad = True  # they were frozen above for this netG update

            train_loss.append(to_scalar([D_cost, G_cost, Wasserstein_D]))
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                np.asarray(train_loss).mean(0),
                time.time() - start_time))
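
Example #3 relies on calc_gradient_penalty, which is not shown here. A minimal sketch of the standard WGAN-GP penalty under that interface; the lambda_gp name and the weight of 10 are illustrative assumptions, not taken from the original script:

import torch

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10.0):
    # Sample random interpolation points between real and fake images
    alpha = torch.rand(real_data.size(0), 1, 1, 1, device=real_data.device)
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = torch.autograd.grad(
        outputs=disc_interpolates, inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True, retain_graph=True)[0]

    # Penalize the deviation of the gradient norm from 1
    gradients = gradients.view(gradients.size(0), -1)
    return lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()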
Example #4
def test():
    start_time = time.time()
    val_loss = []
    for batch_idx, (x, _) in enumerate(test_loader):
        x = x.to(DEVICE)
        x_tilde, z_e_x, z_q_x = model(x)
        loss_recons = F.mse_loss(x_tilde, x)
        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
        val_loss.append(to_scalar([loss_recons, loss_vq]))

    print('\nValidation Completed!\tLoss: {} Time: {:5.3f}'.format(
        np.asarray(val_loss).mean(0),
        time.time() - start_time))
    return np.asarray(val_loss).mean(0)
Example #5
def test():
    start_time = time.time()
    val_loss = []
    with torch.no_grad():
        for batch_idx, (x, label) in enumerate(test_loader):
            x = x.to(DEVICE)
            label = label.to(DEVICE)

            latents, _ = autoencoder.encode(x)
            logits = model(latents.detach(), label)
            logits = logits.permute(0, 2, 3, 1).contiguous()
            loss = criterion(logits.view(-1, K), latents.view(-1))
            val_loss.append(to_scalar(loss))

    print('Validation Completed!\tLoss: {} Time: {}'.format(
        np.asarray(val_loss).mean(0),
        time.time() - start_time))
    return np.asarray(val_loss).mean(0)
Example #6
def train():
    train_loss = []
    for batch_idx, (x, _) in enumerate(train_loader):
        start_time = time.time()
        x = x.to(DEVICE)

        opt.zero_grad()

        x_tilde, z_e_x, z_q_x = model(x)
        z_q_x.retain_grad()

        loss_recons = F.mse_loss(x_tilde, x)
        loss_recons.backward(retain_graph=True)

        # Straight-through estimator
        z_e_x.backward(z_q_x.grad, retain_graph=True)

        # Vector quantization objective
        model.codebook.zero_grad()
        loss_vq = F.mse_loss(z_q_x, z_e_x.detach())
        loss_vq.backward(retain_graph=True)

        # Commitment objective
        loss_commit = LAMDA * F.mse_loss(z_e_x, z_q_x.detach())
        loss_commit.backward()
        opt.step()

        # Track an approximate log-likelihood (in bits per dimension) under a
        # unit-variance Gaussian observation model
        N = x.numel()
        nll = Normal(x_tilde, torch.ones_like(x_tilde)).log_prob(x)
        log_px = nll.sum() / N + np.log(128) - np.log(K * 2)
        log_px /= np.log(2)

        train_loss.append([log_px.item()] + to_scalar([loss_recons, loss_vq]))

        if (batch_idx + 1) % PRINT_INTERVAL == 0:
            print('\tIter [{}/{} ({:.0f}%)]\tLoss: {} Time: {}'.format(
                batch_idx * len(x), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                np.asarray(train_loss)[-PRINT_INTERVAL:].mean(0),
                time.time() - start_time))
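
Examples #2 and #6 assume the model's forward pass returns (x_tilde, z_e_x, z_q_x): the reconstruction, the continuous encoder output, and its quantized version from a nearest-neighbour codebook lookup. A minimal sketch of that lookup; the class and attribute names here are illustrative, not the original module:

import torch
import torch.nn as nn

class Codebook(nn.Module):
    def __init__(self, K, D):
        super().__init__()
        # K discrete codes, each a D-dimensional embedding
        self.embedding = nn.Embedding(K, D)
        self.embedding.weight.data.uniform_(-1.0 / K, 1.0 / K)

    def forward(self, z_e_x):
        # z_e_x: (B, D, H, W) continuous encoder output
        B, D, H, W = z_e_x.shape
        flat = z_e_x.permute(0, 2, 3, 1).reshape(-1, D)
        distances = torch.cdist(flat, self.embedding.weight)   # (B*H*W, K)
        indices = distances.argmin(dim=1)                      # nearest code per position
        z_q_x = self.embedding(indices).view(B, H, W, D).permute(0, 3, 1, 2)
        return z_q_x, indices.view(B, H, W)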
Example #7
    # (snippet from inside the per-episode loop: s_t, losses, total_timesteps
    # and epi are defined earlier in the surrounding script)
    noise_process.reset()
    avg_reward = 0
    for t in count(1):
        a_t = model.sample_action(s_t)
        a_t = a_t + noise_process.sample()

        s_tp1, r_t, done, info = env.step(a_t)

        model.buffer.add(s_t, a_t, r_t, s_tp1, float(not done))  # 1.0 while the episode continues, 0.0 at the terminal step
        avg_reward += r_t

        if done:
            break
        else:
            s_t = s_tp1

        if model.buffer.len >= cf.replay_start_size:
            _loss_a, _loss_c = model.train_batch()
            losses.append(to_scalar([_loss_a, _loss_c]))

    if len(losses) > 0:
        total_timesteps += t
        avg_loss_a, avg_loss_c = np.asarray(losses)[-100:].mean(0)
        print('Episode {}: actor loss: {} critic loss: {} '
              'total_reward: {} timesteps: {} tot_timesteps: {}'.format(
                  epi, avg_loss_a, avg_loss_c, avg_reward, t, total_timesteps))

    if (epi + 1) % 200 == 0:
        model.save_models()
print('Completed training!')
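
Example #7's noise_process (with reset() and sample()) is typically an Ornstein-Uhlenbeck process used for DDPG-style exploration noise. A minimal sketch assuming that interface; the class name and parameter values are illustrative defaults, not taken from the original script:

import numpy as np

class OUNoise:
    def __init__(self, action_dim, mu=0.0, theta=0.15, sigma=0.2):
        self.mu = mu * np.ones(action_dim)
        self.theta = theta
        self.sigma = sigma
        self.reset()

    def reset(self):
        # Restart the process from its mean at the beginning of each episode
        self.state = self.mu.copy()

    def sample(self):
        # dx = theta * (mu - x) dt + sigma * dW, with dt = 1
        dx = self.theta * (self.mu - self.state) + self.sigma * np.random.randn(len(self.state))
        self.state = self.state + dx
        return self.state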