Exemplo n.º 1
0
def train_ae(ae, dataset_size, iters=5000, batch_size=32):
    """Train ``ae`` to map inputs ``x`` to targets ``y`` with binary cross-entropy.

    Builds a ``BoxDataset`` of ``dataset_size`` samples and performs ``iters``
    Adam steps, re-iterating (and reshuffling) the dataloader as often as
    needed to reach that many steps.

    Args:
        ae: autoencoder module, called as ``ae.forward(x)``. Its output is fed
            to ``F.binary_cross_entropy``, so it must lie in [0, 1].
        dataset_size: number of samples requested from ``BoxDataset``.
        iters: total number of optimisation steps.
        batch_size: samples per batch.

    Raises:
        ValueError: if a full pass over the dataloader yields no batches
            (e.g. an empty dataset); previously this looped forever.
    """
    dataset = BoxDataset(dataset_size)
    ts = TimeSeries('Training ae', iters)
    opt_ae = optim.Adam(ae.parameters())
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    i = 0
    while i < iters:
        got_batch = False
        for batch in dataloader:
            got_batch = True
            i += 1
            if i > iters:
                break
            ae.train()
            ae.zero_grad()
            x, y, p = [o.cuda() for o in batch]
            x_hat = ae.forward(x)
            loss_aae = F.binary_cross_entropy(x_hat, y)

            ts.collect("Reconstruction AE loss", loss_aae)

            loss_aae.backward()

            opt_ae.step()
            ae.eval()
            ts.print_every(2)
        if not got_batch:
            # Without this check an empty dataloader never advances `i`.
            raise ValueError('dataloader yielded no batches; cannot train')
Exemplo n.º 2
0
def train_lae(ae,
              lae,
              dataset,
              epochs=20,
              batch_size=128,
              print_every_seconds=2):
    """Train the latent autoencoder ``lae`` to reconstruct codes from ``ae.encoder``.

    ``ae`` is put in eval mode and its parameters are not given to the
    optimizer, so only ``lae`` is updated.

    Args:
        ae: module whose ``encoder`` produces the latent codes ``z``.
        lae: module trained so that ``lae.forward(z)`` matches ``z`` (MSE).
        dataset: map-style dataset yielding ``(x, m, p)`` triples; only ``x``
            is used here.
        epochs: number of passes over ``dataset``.
        batch_size: samples per batch.
        print_every_seconds: progress-print interval passed to ``TimeSeries``.
    """
    ts = TimeSeries('Training ae', epochs * (len(dataset) // batch_size) + 1)
    opt = optim.Adam(lae.parameters(), lr=2e-2)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=4)

    ae.eval()
    for _ in range(epochs):
        for batch in dataloader:
            lae.train()
            x, _, _ = [o.cuda() for o in batch]
            # The encoder is not being optimised: detach its codes so that
            # backward() neither traverses the encoder graph nor accumulates
            # unused gradients into ae's parameters.
            z = ae.encoder.forward(x).detach()

            z_hat = lae.forward(z)

            loss = F.mse_loss(z, z_hat)
            ts.collect("LAE loss", loss)

            opt.zero_grad()
            loss.backward()

            opt.step()
            lae.eval()
            ts.print_every(print_every_seconds)
Exemplo n.º 3
0
def _bootstrapped_mse(target, prediction, bootstrap_ratio):
    """Return the mean of the largest ``1 / bootstrap_ratio`` squared errors.

    With ``bootstrap_ratio <= 1`` this is plain ``F.mse_loss``. At least one
    element is always kept, so tensors with fewer than ``bootstrap_ratio``
    elements no longer yield the NaN mean of an empty top-k selection.
    """
    if bootstrap_ratio > 1:
        sq_err = torch.flatten((prediction - target) ** 2)
        k = max(1, int(sq_err.numel() // bootstrap_ratio))
        return torch.mean(torch.topk(sq_err, k)[0])
    return F.mse_loss(target, prediction)


def train_ae(ae, dataset, iters=5000, batch_size=32, save_every=0, save_path=None, print_every_seconds=10,
             bootstrap_ratio=4):
    """Train ``ae`` to reconstruct its input with a bootstrapped (hard-pixel) MSE.

    Performs ``iters`` Adam steps (lr 2e-4), re-iterating the dataloader as
    often as needed.

    Args:
        ae: autoencoder module, called as ``ae.forward(x)``.
        dataset: map-style dataset yielding ``(x, m, p)`` triples; only ``x``
            is used.
        iters: total number of optimisation steps.
        batch_size: samples per batch.
        save_every: if non-zero (and ``save_path`` is set), call
            ``print_batch`` every ``save_every`` steps.
        save_path: destination passed to ``print_batch``.
        print_every_seconds: progress-print interval passed to ``TimeSeries``.
        bootstrap_ratio: only the top ``1 / bootstrap_ratio`` squared pixel
            errors contribute to the loss; ``<= 1`` means plain MSE.

    Raises:
        ValueError: if a full pass over the dataloader yields no batches
            (previously an infinite loop).
    """
    ts = TimeSeries('Training ae', iters)
    opt_ae = optim.Adam(ae.parameters(), lr=2e-4)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

    i = 0
    while i < iters:
        got_batch = False
        for batch in dataloader:
            got_batch = True
            i += 1
            if i > iters:
                break
            ae.train()
            x, _, _ = [o.cuda() for o in batch]
            x_hat = ae.forward(x)

            loss_aae = _bootstrapped_mse(x, x_hat, bootstrap_ratio)
            ts.collect("Reconstruction AE loss", loss_aae)

            opt_ae.zero_grad()
            loss_aae.backward()

            opt_ae.step()
            ae.eval()
            ts.print_every(print_every_seconds)
            if save_every != 0 and save_path is not None and i % save_every == 0:
                print_batch(x, x_hat, save_path)
        if not got_batch:
            raise ValueError('dataloader yielded no batches; cannot train')
Exemplo n.º 4
0
def train_latnet(ae,
                 latnet,
                 dataset,
                 epochs=20,
                 batch_size=128,
                 save_every=0,
                 save_path=None,
                 print_every_seconds=10,
                 transform_pose=None):
    """Train ``latnet`` to regress, from ``p``, the code ``ae.encoder`` assigns to ``x``.

    ``ae`` is kept in eval mode and is not optimised; only ``latnet`` is
    updated (Adam, lr 2e-4) with an MSE between its prediction and the
    encoder's code.

    Args:
        ae: module whose ``encoder`` produces the regression targets.
        latnet: module mapping ``p`` to a latent code.
        dataset: map-style dataset yielding ``(x, m, p)`` triples.
        epochs: number of passes over ``dataset``.
        batch_size: samples per batch.
        save_every, save_path: accepted for signature compatibility; unused.
        print_every_seconds: progress-print interval passed to ``TimeSeries``.
        transform_pose: optional callable applied to ``p`` before ``latnet``.
    """
    ts = TimeSeries('Training ae', epochs * (len(dataset) // batch_size) + 1)
    opt = optim.Adam(latnet.parameters(), lr=2e-4)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)

    ae.eval()
    for _ in range(epochs):
        for batch in dataloader:
            latnet.train()
            x, _, p = [o.cuda() for o in batch]
            # The encoder only supplies targets: detach so backward() does not
            # traverse it or accumulate unused gradients in ae's parameters.
            z = ae.encoder.forward(x).detach()
            if transform_pose is not None:
                p = transform_pose(p)

            z_hat = latnet.forward(p)

            loss = F.mse_loss(z, z_hat)
            ts.collect("Reconstruction AE loss", loss)

            opt.zero_grad()
            loss.backward()

            opt.step()
            latnet.eval()
            ts.print_every(print_every_seconds)
Exemplo n.º 5
0
def main():
    """Train a 40x40 monochrome-image autoencoder, then traverse its latent space.

    Training runs ``iters`` steps of binary cross-entropy reconstruction on
    batches from ``get_batch()``. Every 25 steps a reconstruction of a fixed
    demo batch is saved to ``iter_XXXXXX_reconstruction.jpg`` and appended to
    the ``autoencoder_reconstruction`` video loop.

    Evaluation rotates the box drawn by ``build_box(20, 20, 10, theta)``
    through a full turn over ``EVAL_FRAMES`` frames. It records the first two
    latent dimensions of each encoding and writes every frame to the
    ``latent_space_traversal`` video loop.

    Relies on module-level ``latent_size``, ``iters``, ``get_batch``,
    ``build_box``, ``Encoder`` and ``Decoder``.
    """
    encoder = Encoder(latent_size)
    decoder = Decoder(latent_size)
    opt_encoder = optim.Adam(encoder.parameters())
    opt_decoder = optim.Adam(decoder.parameters())

    demo_batch, demo_targets = get_batch()
    vid = imutil.VideoLoop('autoencoder_reconstruction')
    ts = TimeSeries('Training', iters)

    # Train the autoencoder to reconstruct batch_target from batch_input.
    for i in range(iters):
        encoder.train()
        decoder.train()

        opt_encoder.zero_grad()
        opt_decoder.zero_grad()

        batch_input, batch_target = get_batch()

        x = torch.Tensor(batch_input).cuda()
        y = torch.Tensor(batch_target).cuda()

        z = encoder(x)
        x_hat = decoder(z)
        loss = F.binary_cross_entropy(x_hat, y)
        ts.collect('Reconstruction loss', loss)

        loss.backward()
        opt_encoder.step()
        opt_decoder.step()

        encoder.eval()
        decoder.eval()

        if i % 25 == 0:
            # Visualisation only: no autograd graph is needed.
            with torch.no_grad():
                filename = 'iter_{:06}_reconstruction.jpg'.format(i)
                x = torch.Tensor(demo_batch).cuda()
                z = encoder(x)
                x_hat = decoder(z)
                img = torch.cat([x[:4], x_hat[:4]])
                caption = 'iter {}: orig. vs reconstruction'.format(i)
                imutil.show(img,
                            filename=filename,
                            resize_to=(256, 512),
                            img_padding=10,
                            caption=caption,
                            font_size=8)
                vid.write_frame(img,
                                resize_to=(256, 512),
                                img_padding=10,
                                caption=caption,
                                font_size=12)
        ts.print_every(2)

    # Now examine the representation that the network has learned.
    EVAL_FRAMES = 360
    ts_eval = TimeSeries('Evaluation', EVAL_FRAMES)
    vid_eval = imutil.VideoLoop('latent_space_traversal')
    with torch.no_grad():
        for i in range(EVAL_FRAMES):
            theta = 2 * np.pi * (i / EVAL_FRAMES)
            box = build_box(20, 20, 10, theta)
            z = encoder(torch.Tensor(box).unsqueeze(0).cuda())[0]
            ts_eval.collect('Latent Dim 1', z[0])
            ts_eval.collect('Latent Dim 2', z[1])
            caption = "Theta={:.2f} Z_0={:.3f} Z_1={:.3f}".format(
                theta, z[0], z[1])
            pixels = imutil.show(box,
                                 resize_to=(512, 512),
                                 caption=caption,
                                 font_size=12,
                                 return_pixels=True)
            vid_eval.write_frame(pixels)
    print(ts)
Exemplo n.º 6
0
def higgins_metric_conv(simulator,
                        true_latent_dim,
                        encoder,
                        encoded_latent_dim,
                        batch_size=16,
                        train_iters=500,
                        test_batch_size=1000):
    """Estimate the Higgins et al. (beta-VAE) disentanglement metric.

    A linear classifier is trained on pairs of simulated images that share
    one uniformly-chosen generative factor. Given the averaged absolute
    difference of the pair's encodings, it predicts which factor was shared.

    Args:
        simulator: callable mapping an ``(n, true_latent_dim)`` array of
            factors in [0, 1) to an array of ``n`` images.
        true_latent_dim: number of generative factors.
        encoder: network applied to image batches; ``encoder(x)[0]`` is used.
        encoded_latent_dim: width of the per-pair difference vector fed to
            the classifier.
        batch_size: pairs per training step.
        train_iters: number of classifier training steps (must be >= 1).
        test_batch_size: number of pairs in the final step. When
            ``train_iters >= 2`` that step is scored before the optimizer
            updates on it, and serves as the reported test accuracy.

    Returns:
        Classification accuracy on the final step's batch.

    Raises:
        ValueError: if ``train_iters < 1`` (no accuracy would be computed).
    """
    if train_iters < 1:
        raise ValueError('train_iters must be >= 1, got {}'.format(train_iters))

    num_repeats = 5  # pairs averaged per label, to reduce noise in z_diff

    def encode(images):
        x = torch.FloatTensor(images).cuda()
        if len(x.shape) < 4:
            # Add a channel axis to batches of single-channel images.
            x = x.unsqueeze(1)
        # NOTE(review): `[0]` presumably selects the first element of a tuple
        # returned by the encoder (e.g. a latent mean) -- confirm.
        return encoder(x)[0]

    def generate_equivariance_test_batch(y_labels):
        n_pairs = len(y_labels)
        random_factors = np.random.uniform(size=(n_pairs, 2, true_latent_dim))

        # Each pair of images has the factor named by its label in common.
        pair_idx = np.arange(n_pairs)
        random_factors[pair_idx, 0, y_labels] = random_factors[pair_idx, 1, y_labels]

        images_left = simulator(random_factors[:, 0, :])
        images_right = simulator(random_factors[:, 1, :])

        encoded_left = encode(images_left)
        encoded_right = encode(images_right)
        # Collapse the two trailing (presumably spatial) axes of the encoding.
        z_diff = torch.abs(encoded_left -
                           encoded_right).sum(dim=-1).sum(dim=-1)
        return z_diff.data.cpu().numpy()

    linear_model = LinearClassifier(encoded_latent_dim, true_latent_dim)
    optimizer = torch.optim.Adam(linear_model.parameters())
    ts = TimeSeries('Computing Higgins Metric', train_iters)

    current_batch_size = batch_size
    for train_iter in range(train_iters):
        # The last step (when there is more than one) uses an extra big batch
        # whose accuracy is reported as the test accuracy.
        if train_iter > 0 and train_iter == train_iters - 1:
            current_batch_size = test_batch_size

        # For each pair, select a factor to share.
        y_labels = np.random.randint(0, true_latent_dim, size=current_batch_size)
        z_diffs = np.zeros((num_repeats, current_batch_size, encoded_latent_dim))
        for r in range(num_repeats):
            z_diffs[r] = generate_equivariance_test_batch(y_labels)
        z_diff = torch.FloatTensor(np.mean(z_diffs, axis=0)).cuda()

        # Now given z_diff, predict y_labels.
        optimizer.zero_grad()
        target = torch.LongTensor(y_labels).cuda()
        logits = linear_model(z_diff)
        # Scored before optimizer.step(), so the final batch is unseen data.
        num_correct = (logits.argmax(dim=1) == target).sum().item()

        loss = nn.functional.nll_loss(torch.log_softmax(logits, dim=1), target)
        loss.backward()
        optimizer.step()

        # Track the training accuracy over time.
        ts.collect('NLL Loss', loss)
        ts.collect('Train accuracy', num_correct / current_batch_size)
        ts.print_every(2)
    print(ts)
    accuracy = num_correct / current_batch_size
    print('Test Accuracy: {}'.format(accuracy))

    return accuracy
Exemplo n.º 7
0
        torch.zeros(transition.fc2.weight.shape).cuda())
    # Log the sparsity penalty and add it to the prediction loss.
    ts.collect('Sparsity loss', l1_loss)

    loss = pred_loss + l1_loss

    # Backpropagate and step the encoder, decoder and transition optimizers.
    # NOTE(review): no zero_grad() is visible in this fragment -- confirm the
    # gradients are cleared earlier in the loop body.
    loss.backward()
    opt_encoder.step()
    opt_decoder.step()
    opt_transition.step()

    # Every 1000 iterations: show target vs. predicted frames, render the
    # causal graph computed from `transition` into the video, and run the
    # latent-dimension demos on the first 9 entries of `before`.
    if i % 1000 == 0:
        imutil.show(target, caption='Target')
        imutil.show(predicted, caption='Predicted')
        scm = compute_causal_graph(transition, latent_size, num_actions)
        caption = 'Prediction Loss {:.03f}'.format(pred_loss)
        vid.write_frame(render_causal_graph(scm), caption=caption)
        demo_latent_dimensions(before[:9], encoder, decoder, transition,
                               latent_size, num_actions)
        demo_latent_video(before[:9],
                          encoder,
                          decoder,
                          transition,
                          latent_size,
                          num_actions,
                          epoch=i)
    ts.print_every(1)

# After the loop: finalize the video and print the collected statistics.
vid.finish()

print(ts)
Exemplo n.º 8
0
def train(latent_dim, datasource, num_actions, num_rewards, encoder, decoder,
          reward_predictor, discriminator, transition):
    """Train the encoder / decoder / transition / reward-predictor world model.

    Hyperparameters are read from the global ``args`` namespace. Each
    iteration does the following:

    * samples ``batch_size`` trajectories from ``datasource``;
    * encodes the first three frames;
    * rolls the latent state forward with ``transition``, penalising
      reward-prediction and pixel-reconstruction error at every step;
    * optionally adds latent overshooting and two counterfactual
      regularisers.

    Every ``ITERS_PER_VIDEO`` iterations the networks are evaluated and their
    ``state_dict``s are written to ``model-*.pth`` in the working directory.

    Args:
        latent_dim: number of latent factors; presumably the size of dim 1
            of ``z`` (it must match ``unswapped_factor_map``).
        datasource: object providing ``get_trajectories(batch_size, horizon)``.
        num_actions: size of the discrete action space (used for one-hot).
        num_rewards: not used in this function.
        encoder, decoder, reward_predictor, transition: the trained networks.
        discriminator: only put in train mode, evaluated and saved here; it
            does not enter any loss.
    """
    # Imported once here rather than on every training iteration.
    from torch.nn.utils.clip_grad import clip_grad_value_

    batch_size = args.batch_size
    train_iters = args.train_iters
    td_lambda_coef = args.td_lambda
    td_steps = args.td_steps
    truncate_bptt = args.truncate_bptt
    enable_td = args.latent_td
    enable_latent_overshooting = args.latent_overshooting
    learning_rate = args.learning_rate
    min_prediction_horizon = args.horizon_min
    max_prediction_horizon = args.horizon_max
    finetune_reward = args.finetune_reward
    REWARD_COEF = args.reward_coef
    ACTIVATION_L1_COEF = args.activation_l1_coef
    TRANSITION_L1_COEF = args.transition_l1_coef
    counterfactual_horizon = args.counterfactual_horizon
    start_iter = args.start_iter

    opt_enc = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    opt_dec = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
    opt_trans = torch.optim.Adam(transition.parameters(), lr=learning_rate)
    opt_disc = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
    opt_pred = torch.optim.Adam(reward_predictor.parameters(),
                                lr=learning_rate)
    ts = TimeSeries('Training Model', train_iters, tensorboard=True)

    for train_iter in range(start_iter, train_iters + 1):
        if train_iter % ITERS_PER_VIDEO == 0:
            print('Evaluating networks...')
            evaluate(datasource,
                     encoder,
                     decoder,
                     transition,
                     discriminator,
                     reward_predictor,
                     latent_dim,
                     train_iter=train_iter)
            print('Saving networks to filesystem...')
            torch.save(transition.state_dict(), 'model-transition.pth')
            torch.save(encoder.state_dict(), 'model-encoder.pth')
            torch.save(decoder.state_dict(), 'model-decoder.pth')
            torch.save(discriminator.state_dict(), 'model-discriminator.pth')
            torch.save(reward_predictor.state_dict(),
                       'model-reward_predictor.pth')

        # Curriculum: the prediction horizon grows linearly from
        # horizon_min to horizon_max as training progresses.
        theta = train_iter / train_iters
        pred_delta = max_prediction_horizon - min_prediction_horizon
        prediction_horizon = min_prediction_horizon + int(pred_delta * theta)

        train_mode(
            [encoder, decoder, transition, discriminator, reward_predictor])

        # Train encoder/transition/decoder
        opt_enc.zero_grad()
        opt_dec.zero_grad()
        opt_trans.zero_grad()
        opt_pred.zero_grad()

        states, rewards, dones, actions = datasource.get_trajectories(
            batch_size, prediction_horizon)
        states = torch.Tensor(states).cuda()
        rewards = torch.Tensor(rewards).cuda()
        dones = torch.Tensor(dones.astype(int)).cuda()

        # Encode the initial state (using the first 3 frames)
        # Given t, t+1, t+2, encoder outputs the state at time t+1
        z = encoder(states[:, 0:3])
        z_orig = z.clone()

        # But wait, here's the problem: We can't use the encoded initial state as
        # an initial state of the dynamical system and expect the system to work
        # The dynamical system needs to have something like the Echo State Property
        # So the dynamical parts need to run long enough to reach a steady state

        # Keep track of "done" states to stop a training trajectory at the final time step
        active_mask = torch.ones(batch_size).cuda()

        loss = 0
        lo_loss = 0
        lo_z_set = {}
        # Given the state encoded at t=2, predict state at t=3, t=4, ...
        for t in range(1, prediction_horizon - 1):
            active_mask = active_mask * (1 - dones[:, t])

            # Predict reward
            expected_reward = reward_predictor(z)
            actual_reward = rewards[:, t]
            reward_difference = torch.mean(
                torch.mean(
                    (expected_reward - actual_reward)**2, dim=1) * active_mask)
            ts.collect('Rd Loss t={}'.format(t), reward_difference)
            loss += theta * REWARD_COEF * reward_difference  # Normalize by height * width

            # Reconstruction loss
            target_pixels = states[:, t]
            predicted = torch.sigmoid(decoder(z))
            rec_loss_batch = decoder_pixel_loss(target_pixels, predicted)

            if truncate_bptt and t > 1:
                z.detach_()

            rec_loss = torch.mean(rec_loss_batch * active_mask)
            ts.collect('Reconstruction t={}'.format(t), rec_loss)
            loss += rec_loss

            # Apply activation L1 loss
            #l1_values = z.abs().mean(-1).mean(-1).mean(-1)
            #l1_loss = ACTIVATION_L1_COEF * torch.mean(l1_values * active_mask)
            #ts.collect('L1 t={}'.format(t), l1_loss)
            #loss += theta * l1_loss

            # Predict transition to the next state
            onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
            new_z = transition(z, onehot_a)

            # Apply transition L1 loss
            #t_l1_values = ((new_z - z).abs().mean(-1).mean(-1).mean(-1))
            #t_l1_loss = TRANSITION_L1_COEF * torch.mean(t_l1_values * active_mask)
            #ts.collect('T-L1 t={}'.format(t), t_l1_loss)
            #loss += theta * t_l1_loss

            z = new_z

            if enable_latent_overshooting:
                # Latent Overshooting, Hafner et al.
                lo_z_set[t] = encoder(states[:, t - 1:t + 2])

                # For each previous t_left, step forward to t
                for t_left in range(1, t):
                    a = torch.eye(num_actions)[actions[:, t - 1]].cuda()
                    lo_z_set[t_left] = transition(lo_z_set[t_left], a)
                for t_a in range(2, t - 1):
                    # It's like TD but only N:1 for all N
                    predicted_activations = lo_z_set[t_a]
                    target_activations = lo_z_set[t].detach()
                    lo_loss_batch = latent_state_loss(target_activations,
                                                      predicted_activations)
                    lo_loss += td_lambda_coef * torch.mean(
                        lo_loss_batch * active_mask)

        if enable_latent_overshooting:
            ts.collect('LO total', lo_loss)
            loss += theta * lo_loss

        # COUNTERFACTUAL DISENTANGLEMENT REGULARIZATION
        # Suppose that our representation is ideally, perfectly disentangled
        # Then the PGM has no edges, the causal graph is just nodes with no relationships
        # In this case, it should be true that intervening on any one factor has no effect on the others
        # One fun way of intervening is swapping factors, a la FactorVAE
        # If we intervene on some dimensions, the other dimensions should be unaffected
        if enable_cf_shuffle_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: a bizarro world where two dimensions are swapped.
            # Clone so the in-place swaps below do not corrupt z_orig, which
            # the control-bias regulariser also starts from.
            z_cf_b = z_orig.clone()
            unswapped_factor_map = torch.ones((batch_size, latent_dim)).cuda()
            for i in range(batch_size):
                idx_a = np.random.randint(latent_dim)
                idx_b = np.random.randint(latent_dim)
                unswapped_factor_map[i, idx_a] = 0
                unswapped_factor_map[i, idx_b] = 0
                # Swap via advanced indexing. The right-hand side is a copy,
                # so each slot receives the other's ORIGINAL values. A tuple
                # swap of basic-indexed slices does not work here: those
                # slices are views, so the second assignment would read the
                # already-overwritten slot and duplicate one factor.
                z_cf_b[i, [idx_a, idx_b]] = z_cf_b[i, [idx_b, idx_a]]
            # But we take the same actions
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            # Every UNSWAPPED dimension should be as similar as possible to its bizarro-world equivalent
            cf_loss = torch.abs(z_cf_a - z_cf_b).mean(-1).mean(
                -1) * unswapped_factor_map
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss.mean(-1) * active_mask)
            loss += cf_loss
            ts.collect('CF Disentanglement Loss', cf_loss)

        # COUNTERFACTUAL ACTION-CONTROL REGULARIZATION
        # In difficult POMDPs, deep neural networks can suffer from learned helplessness
        # They learn, rationally, that their actions have no causal influence on the reward
        # This is undesirable: the learned model should assume that outcomes are controllable
        if enable_control_bias_loss and train_iter % CF_REGULARIZATION_RATE == 0:
            # Counterfactual scenario A: our memory of what really happened
            z_cf_a = z.clone()
            # Counterfactual scenario B: our imagination of what might have happened
            z_cf_b = z_orig
            # Instead of the regular actions, apply an alternate policy
            cf_actions = actions.copy()
            np.random.shuffle(cf_actions)
            for t in range(1, counterfactual_horizon):
                onehot_a = torch.eye(num_actions)[cf_actions[:, t]].cuda()
                z_cf_b = transition(z_cf_b, onehot_a)
            eps = .001  # for numerical stability
            cf_loss = -torch.log(
                torch.abs(z_cf_a - z_cf_b).mean(-1).mean(-1).mean(-1) + eps)
            cf_loss = CF_REGULARIZATION_LAMBDA * torch.mean(
                cf_loss * active_mask)
            loss += cf_loss
            ts.collect('CF Control Bias Loss', cf_loss)

        loss.backward()

        clip_grad_value_(encoder.parameters(), 0.1)
        clip_grad_value_(transition.parameters(), 0.1)
        clip_grad_value_(decoder.parameters(), 0.1)

        opt_pred.step()
        # When fine-tuning the reward head only, the world model stays fixed.
        if not finetune_reward:
            opt_enc.step()
            opt_dec.step()
            opt_trans.step()
        ts.print_every(10)
    print(ts)
    print('Finished')
Exemplo n.º 9
0
def main():
    """Jointly train an autoencoder and a generator that maps ``p`` into its latent space.

    Each of ``iters`` steps does two things:

    * Reconstructs ``y`` from ``x`` through ``encoder``/``decoder`` (binary
      cross-entropy).
    * Pulls ``generator(p)`` towards ``encoder(x)``. The generation loss is
      the sum over the batch of ``1 - cosine_similarity(z_enc, z_gen)``.

    Once the reconstruction loss drops below 0.01, the loss weights shift
    towards the generation term.

    Relies on module-level ``iters``, ``latent_size``, ``mid_size``,
    ``build_dataset``, ``Encoder``, ``Generator`` and ``Decoder``.
    """
    dataset = build_dataset(size=iters)

    def get_batch(size=32):
        """Return a random contiguous window of ``size`` examples, transposed
        into per-field tuples."""
        # Valid window starts are 0 .. len(dataset) - size inclusive; the
        # upper bound of np.random.randint is exclusive, hence the + 1.
        idx = np.random.randint(len(dataset) - size + 1)
        examples = dataset[idx:idx + size]
        return zip(*examples)

    encoder = Encoder(latent_size, mid_size)
    generator = Generator(latent_size)
    decoder = Decoder(latent_size, mid_size)
    opt_encoder = optim.Adam(encoder.parameters())
    opt_generator = optim.Adam(generator.parameters())
    opt_decoder = optim.Adam(decoder.parameters())

    ts = TimeSeries('Training', iters)

    for i in range(iters):
        encoder.train()
        generator.train()
        decoder.train()

        opt_encoder.zero_grad()
        opt_generator.zero_grad()
        opt_decoder.zero_grad()

        batch_input, batch_view_output, batch_target = get_batch()
        x = torch.Tensor(batch_input).cuda()
        y = torch.Tensor(batch_view_output).cuda()
        p = torch.Tensor(batch_target).cuda()

        z_enc = encoder(x)
        z_gen = generator(p)
        x_hat = decoder(z_enc)
        loss_aae = F.binary_cross_entropy(x_hat, y)

        # sum_b (1 - cos(z_enc[b], z_gen[b]))
        z_dot_product = (z_enc * z_gen).sum(-1)
        z_enc_norm = torch.norm(z_enc, dim=1)
        z_gen_norm = torch.norm(z_gen, dim=1)
        loss_gen = z_dot_product / z_enc_norm / z_gen_norm
        loss_gen = z_enc.shape[0] - loss_gen.sum()

        # Once reconstruction is good enough, emphasise the generation term.
        if loss_aae < 0.01:
            k_gen = 1
            k_aae = 0.01
        else:
            k_gen = 0.1
            k_aae = 1

        loss = k_gen * loss_gen + k_aae * loss_aae
        ts.collect('Reconstruction loss', loss_aae)
        ts.collect('Generation loss', loss_gen)

        loss.backward()
        opt_encoder.step()
        opt_generator.step()
        opt_decoder.step()

        encoder.eval()
        generator.eval()
        decoder.eval()

        ts.print_every(2)
Exemplo n.º 10
0
def _bootstrapped_mse(target, prediction, bootstrap_ratio):
    """Return the mean of the largest ``1 / bootstrap_ratio`` squared errors.

    With ``bootstrap_ratio <= 1`` this is plain ``F.mse_loss``. At least one
    element is always kept, so tensors with fewer than ``bootstrap_ratio``
    elements no longer yield the NaN mean of an empty top-k selection.
    """
    if bootstrap_ratio > 1:
        sq_err = torch.flatten((prediction - target) ** 2)
        k = max(1, int(sq_err.numel() // bootstrap_ratio))
        return torch.mean(torch.topk(sq_err, k)[0])
    return F.mse_loss(target, prediction)


def _cycle_batches(dataloader, start, stop):
    """Yield ``(step, batch)`` for steps ``start + 1 .. stop``, re-iterating
    ``dataloader`` as often as needed.

    Raises ValueError if a full pass over ``dataloader`` yields nothing, which
    would otherwise loop forever.
    """
    step = start
    while step < stop:
        got_batch = False
        for batch in dataloader:
            got_batch = True
            step += 1
            yield step, batch
            if step >= stop:
                return
        if not got_batch:
            raise ValueError('dataloader yielded no batches; cannot train')


def train_laeae(ae,
                lae,
                dataset,
                iters_together=2000,
                iters_splitted=5000,
                batch_size=32,
                save_every=0,
                save_path=None,
                print_every_seconds=10,
                bootstrap_ratio=4):
    """Train ``ae`` together with ``ae.lae``, then train ``lae`` separately.

    Phase 1 (``iters_together`` steps) optimises ``ae`` on two terms: a
    bootstrapped pixel MSE, plus ``0.01`` times the MSE between
    ``ae.encoder`` codes and their reconstruction by ``ae.lae``.

    Phase 2 (``iters_splitted`` steps) starts with ``ae.use_lae(False)``.
    Each step first updates ``ae`` on the bootstrapped pixel MSE alone. It
    then updates ``lae``, through its own optimizer, to reconstruct the
    encoder's detached codes.

    Args:
        ae: autoencoder exposing ``encoder``, ``lae`` and ``use_lae``.
        lae: separate latent autoencoder trained in phase 2.
        dataset: map-style dataset yielding ``(x, m, p)`` triples; only ``x``
            is used.
        iters_together, iters_splitted: steps in phase 1 and phase 2.
        batch_size: samples per batch.
        save_every: if non-zero (and ``save_path`` is set), call
            ``print_batch`` every ``save_every`` global steps.
        save_path: destination passed to ``print_batch``.
        print_every_seconds: progress-print interval passed to ``TimeSeries``.
        bootstrap_ratio: only the top ``1 / bootstrap_ratio`` squared errors
            contribute to the bootstrapped losses; ``<= 1`` means plain MSE.

    Raises:
        ValueError: if a full pass over the dataloader yields no batches.
    """
    total_iters = iters_together + iters_splitted
    ts = TimeSeries('Training ae', total_iters)
    opt_ae = optim.Adam(ae.parameters(), lr=2e-4)
    dataloader = DataLoader(dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=1)
    opt_lae = optim.Adam(lae.parameters(), lr=2e-4)

    # Phase 1: ae and its internal lae are trained jointly. The step counter
    # ends at exactly iters_together, so phase 2 runs exactly iters_splitted
    # steps. (Previously phase 1 overshot the counter by one.)
    for i, batch in _cycle_batches(dataloader, 0, iters_together):
        ae.train()
        x, _, _ = [o.cuda() for o in batch]
        x_hat = ae.forward(x)
        loss_aae = _bootstrapped_mse(x, x_hat, bootstrap_ratio)

        z = ae.encoder.forward(x)
        z_hat = ae.lae.forward(z)
        loss_sim_latent = F.mse_loss(z, z_hat)
        loss = loss_aae + 0.01 * loss_sim_latent
        ts.collect("Reconstruction AE loss", loss_aae)
        ts.collect("Reconstruction LAE loss", loss_sim_latent)

        opt_ae.zero_grad()
        loss.backward()

        opt_ae.step()
        ae.eval()
        ts.print_every(print_every_seconds)
        if save_every != 0 and save_path is not None and i % save_every == 0:
            print_batch(x, x_hat, save_path)

    ae.use_lae(False)

    # Phase 2: ae and the external lae are trained with separate optimizers.
    for i, batch in _cycle_batches(dataloader, iters_together, total_iters):
        ae.train()
        x, _, _ = [o.cuda() for o in batch]
        x_hat = ae.forward(x)
        loss_aae = _bootstrapped_mse(x, x_hat, bootstrap_ratio)
        ts.collect("Reconstruction AE loss", loss_aae)

        opt_ae.zero_grad()
        loss_aae.backward()

        opt_ae.step()
        ae.eval()

        # ------------------ LAE ------------------------ #
        lae.train()
        # Detach: only lae is optimised here, so backward() must not walk the
        # encoder graph or accumulate gradients into ae.
        z = ae.encoder.forward(x).detach()
        z_hat = lae.forward(z)
        loss_lae = _bootstrapped_mse(z, z_hat, bootstrap_ratio)

        # Log the LAE loss itself (previously this logged loss_aae).
        ts.collect("Reconstruction LAE loss", loss_lae)

        opt_lae.zero_grad()
        loss_lae.backward()

        opt_lae.step()
        lae.eval()

        ts.print_every(print_every_seconds)
        if save_every != 0 and save_path is not None and i % save_every == 0:
            print_batch(x, x_hat, save_path)
Exemplo n.º 11
0
        # One optimisation step of the classifier netC on the current batch.
        netC.zero_grad()
        predictions = netC(data_batch, ts)
        loss = F.cross_entropy(predictions, labels)
        loss.backward()
        optimizerC.step()

        # For classification, we want to track accuracy during training
        # Note that accuracy is our true objective, but we optimize with
        # cross_entropy because it is smoothly differentiable
        pred_confidence, pred_argmax = predictions.max(dim=1)
        correct = torch.sum(pred_argmax == labels)
        accuracy = float(correct) / len(data_batch)

        ts.collect('Training Loss', loss)
        ts.collect('Training Accuracy', accuracy)
        ts.print_every(n_sec=4)

    # Evaluate on the test set and accumulate the number of correct predictions.
    # NOTE(review): this loop runs with autograd enabled, and no netC.eval()
    # is visible in this fragment -- consider torch.no_grad() / eval mode.
    total_correct = 0
    for data_batch, labels in test_dataloader:
        data_batch = data_batch.to(device)
        labels = labels.to(device)

        predictions = netC(data_batch, ts)
        pred_confidence, pred_argmax = predictions.max(dim=1)
        correct = torch.sum(pred_argmax == labels)
        accuracy = float(correct) / len(data_batch)
        total_correct += correct

        # NOTE(review): `loss` is never recomputed in this loop, so
        # 'Testing Loss' records the last TRAINING batch's loss, not a test loss.
        ts.collect('Testing Loss', loss)
        ts.collect('Testing Accuracy', accuracy)
        ts.print_every(n_sec=4)