Example No. 1
def test_profile(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2 and commit.
        profile_step_start(2)
        profile_sync_time(1.0)
        profile_sync_time(2.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
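        # Profile keys appear to be (num_nodes, num_replicas, local_bsz).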
        key = (1, 1, 2)
        assert len(profile) == 1
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 3.0
        assert profile[key]["optim_step_time"] > 0.0
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 1
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 3.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 0
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0
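The three calls exercised above bracket a single training step: profile_step_start records the local batch size at the start of the step, profile_sync_time accumulates gradient-synchronization time within it, and profile_step_commit folds the measurements into the profile entry for the current placement. As a rough illustration of that intended usage (profiled_step and run_step are hypothetical names, not part of adaptdl's API), a step driver might look like:

import time
from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit)

def profiled_step(local_bsz, run_step):
    # Mark the start of a step taken at this local batch size.
    profile_step_start(local_bsz)
    start = time.time()
    run_step()  # placeholder for the forward/backward pass
    # Charge the elapsed time as synchronization time; real code would
    # time only the gradient all-reduce.
    profile_sync_time(time.time() - start)
    # Fold the measurements into the profile for the current placement.
    profile_step_commit()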
Example No. 2
def test_profile_accumulation(num_replicas):
    import adaptdl.checkpoint
    from adaptdl.env import num_restarts
    from adaptdl.torch._metrics import (
            profile_step_start, profile_sync_time,
            profile_step_commit, _metrics_state, _fit_perf_params)
    if num_restarts() == 0:
        profile = _metrics_state().profile
        assert len(profile) == 0
        # Profile local_bsz=1 but don't commit.
        profile_step_start(1)
        profile_sync_time(1.0)
        # Profile local_bsz=2: two accumulation steps, then an optimizer step.
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_step_commit(accumulation_step=True)
        profile_step_start(2)
        profile_sync_time(4.0)
        profile_step_commit(accumulation_step=False)
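        # Profile local_bsz=5 the same way.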
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_step_commit(accumulation_step=True)
        profile_step_start(5)
        profile_sync_time(6.0)
        profile_step_commit(accumulation_step=False)
        # Ensure profile is updated correctly.
        profile = _metrics_state().profile
        key = (1, 1, 2)
        assert len(profile) == 2
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["accum_step_time"] > 0.0
        assert profile[key]["optim_step_time"] > 0.0
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        profile_step_start(3)
        profile_step_commit(accumulation_step=True)
        # Check that fitting the perf params works even without a
        # final accumulation_step=False commit.
        for val in profile.values():
            # Ensure step time is at least sync time.
            val["optim_step_time"] += val["optim_sync_time"]
        _fit_perf_params()
        # Checkpoint and restart.
        adaptdl.checkpoint.save_all_states()
        return num_replicas
    elif num_restarts() == 1:
        profile = _metrics_state().profile
        # Ensure checkpoint is loaded correctly.
        key = (1, 1, 2)
        assert len(profile) == 3
        assert profile[key]["accum_count"] == 2
        assert profile[key]["optim_count"] == 1
        assert profile[key]["optim_sync_time"] == 4.0
        assert profile[key]["optim_step_time"] > 0.0
        # Profile local_bsz=3 and commit twice.
        profile_step_start(3)
        profile_sync_time(2.0)
        profile_sync_time(3.0)
        profile_step_commit()
        key = (1, num_replicas, 3)
        old_step_time = profile[key]["optim_step_time"]
        profile_step_start(3)
        profile_sync_time(3.0)
        profile_sync_time(4.0)
        profile_step_commit()
        # Ensure profile is updated correctly.
        if num_replicas == 1:
            assert len(profile) == 3
        else:
            assert len(profile) == 4
        assert profile[key]["accum_count"] == 0 if num_replicas > 1 else 2
        assert profile[key]["optim_count"] == 2
        assert profile[key]["optim_sync_time"] == 12.0
        assert profile[key]["optim_step_time"] > old_step_time > 0.0
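This example separates accumulation steps from optimizer steps: each commit with accumulation_step=True increments accum_count for the key, while the final accumulation_step=False commit increments optim_count and carries the synchronization time. A sketch of the loop shape this models (everything except the three profiling calls is a hypothetical placeholder):

from adaptdl.torch._metrics import (
        profile_step_start, profile_sync_time, profile_step_commit)

def profiled_round(local_bsz, accum_steps, run_backward, sync_grads):
    # Gradient accumulation: several backward passes, one optimizer step.
    for i in range(accum_steps + 1):
        profile_step_start(local_bsz)
        run_backward()  # placeholder forward/backward pass
        last = (i == accum_steps)
        if last:
            # Gradients are only synchronized on the final micro-step;
            # sync_grads() is assumed to return the elapsed sync time.
            profile_sync_time(sync_grads())
        profile_step_commit(accumulation_step=not last)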
Example No. 3
def train(epoch):
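    # Assumes the usual DCGAN-script globals are defined at module level:
    # dataloader, netD, netG, criterion, optimizerD, optimizerG, device, nz,
    # real_label, fake_label, G_losses, D_losses, scheduleD, scheduleG, and
    # imports torch, adaptdl, adaptdl.torch as adl, adaptdl.torch._metrics
    # as metrics, and torch.utils.tensorboard.SummaryWriter.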
    # For each batch in the dataloader
    stats = adl.Accumulator()
    for i, data in enumerate(dataloader, 0):
        data = data[0]
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data.to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float,
                           device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Add the gradients from the all-real and all-fake batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        stats["g_loss_sum"] += errG.item()
        stats["d_loss_sum"] += errD.item()
    stats["norm"] += metrics._metrics_state().grad_params[0]
    stats["var"] += metrics._metrics_state().grad_params[1]
    stats["replicas"] += 1.0
    scheduleD.step()
    scheduleG.step()

    with stats.synchronized():
        with SummaryWriter(adaptdl.get_tensorboard_dir()) as writer:
            writer.add_scalar("Loss/G",
                              stats["g_loss_sum"] / stats["replicas"], epoch)
            writer.add_scalar("Loss/D",
                              stats["d_loss_sum"] / stats["replicas"], epoch)
            writer.add_scalar("Performance/GlobalBatchsize",
                              b_size * stats["replicas"], epoch)
            writer.add_scalar("Performance/Replicas", stats["replicas"], epoch)
            writer.add_scalar("Stats/Variance",
                              stats["norm"] / stats["replicas"], epoch)
            writer.add_scalar("Stats/Norm", stats["var"] / stats["replicas"],
                              epoch)
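In an adaptdl program, an epoch loop like this one is typically driven by adaptdl.torch.remaining_epochs_until, which skips epochs already completed before a checkpoint-restart. A minimal sketch, assuming num_epochs and the globals above are defined elsewhere:

import adaptdl.torch as adl

num_epochs = 25  # illustrative value

# remaining_epochs_until resumes from the last completed epoch after an
# adaptdl restart, so train() only runs for unfinished epochs.
for epoch in adl.remaining_epochs_until(num_epochs):
    train(epoch)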