Example 1
def test_lstm_saturation_embed_runs():
    save_path = TEMP_DIRNAME
    # Run 2
    timeseries_method = 'last_timestep'

    model = torch.nn.Sequential().to(device)
    lstm = torch.nn.LSTM(10, 88, 2)
    model.add_module('lstm', lstm)

    writer = CSVandPlottingWriter(save_path, fontsize=16)
    saturation = SaturationTracker(save_path, [writer],
                                   model,
                                   stats=['lsat', 'idim', 'embed'],
                                   timeseries_method=timeseries_method,
                                   device=device)

    input = torch.randn(5, 3, 10).to(device)
    output, (hn, cn) = model(input)
    assert saturation.logs['train-covariance-matrix'][
        'lstm'].saved_samples.shape == torch.Size([5, 88])

    input = torch.randn(8, 3, 10).to(device)
    output, (hn, cn) = model(input)
    assert saturation.logs['train-covariance-matrix'][
        'lstm'].saved_samples.shape == torch.Size([8, 88])
    saturation.add_saturations()
    saturation.close()
    return True
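
The examples on this page omit their imports and the shared `device` / `TEMP_DIRNAME` fixtures. A minimal setup sketch, assuming delve's usual module layout (exact import paths may differ between delve versions):

import tempfile

import torch

# Assumed imports; paths follow delve's layout but may vary across versions.
from delve import SaturationTracker
from delve.writers import CSVandPlottingWriter, NPYWriter, PrintWriter
from delve.pca_layers import (
    Conv2DPCALayer,
    change_all_pca_layer_thresholds_and_inject_random_directions,
)

# Shared fixtures assumed by the tests below.
TEMP_DIRNAME = tempfile.mkdtemp()
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'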
Example 2
def test_lstm_saturation_runs():
    save_path = TEMP_DIRNAME

    # Run 1
    timeseries_method = 'timestepwise'

    model = torch.nn.Sequential().to(device)
    lstm = torch.nn.LSTM(10, 88, 2)
    lstm.name = 'lstm2'
    model.add_module('lstm', lstm)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    saturation = SaturationTracker(save_path, [writer],
                                   model,
                                   stats=['lsat', 'idim'],
                                   timeseries_method=timeseries_method,
                                   device=device)

    input = torch.randn(5, 3, 10).to(device)
    output, (hn, cn) = model(input)
    saturation.close()
Example 3
def test_dense_saturation_runs():
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Linear(10, 88)).to(device)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    test_input = torch.randn(5, 10).to(device)
    _ = model(test_input)
    return True
Example 4
def test_conv_saturation_runs_with_pca():
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Conv2d(4, 88, (3, 3)),
                                Conv2DPCALayer(88)).to(device)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    test_input = torch.randn(32, 4, 10, 10).to(device)
    _ = model(test_input)
    model.eval()
    _ = model(test_input)
    return True
Example 5
def test_conv_saturation_runs_with_pca_injecting_random_directions():
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Conv2d(4, 88, (3, 3)),
                                Conv2DPCALayer(88)).to(device)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    test_input = torch.randn(32, 4, 10, 10).to(device)
    _ = model(test_input)
    model.eval()
    x = model(test_input)
    change_all_pca_layer_thresholds_and_inject_random_directions(0.99, model)
    y = model(test_input)
    # Outputs should differ once random directions are injected;
    # compare tensors properly instead of returning an elementwise mask.
    return not torch.equal(x, y)
Example 6
def test_dense_saturation_runs_with_many_writers():
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Linear(10, 88)).to(device)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    writer2 = NPYWriter(save_path)
    writer3 = PrintWriter()
    sat = SaturationTracker(save_path, [writer, writer2, writer3],
                            model,
                            stats=['lsat', 'idim'],
                            device=device)

    test_input = torch.randn(5, 10).to(device)
    _ = model(test_input)
    sat.add_scalar("test_accuracy", 1.0)
    sat.add_saturations()

    return True
Example 7
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)
    x_test = torch.randn(N, D_in)
    y_test = torch.randn(N, D_out)

    # You can watch specific layers by handing them to delve as a list,
    # or hand over the entire Module object and let delve search for
    # recordable layers (a sketch of that variant follows this example).
    model = TwoLayerNet(D_in, H, D_out)

    x, y, model = x.to(device), y.to(device), model.to(device)
    x_test, y_test = x_test.to(device), y_test.to(device)

    layers = [model.linear1, model.linear2]
    stats = SaturationTracker('regression/h{}'.format(h),
                              save_to="plotcsv",
                              modules=layers,
                              device=device,
                              stats=["lsat", "lsat_eval"])

    loss_fn = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
    steps_iter = trange(2000, desc='steps', leave=True, position=0)
    steps_iter.write("{:^80}".format(
        "Regression - TwoLayerNet - Hidden layer size {}".format(h)))
    for step in steps_iter:
        # training step
        model.train()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
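
As noted in the comment inside this example, the explicit layer list can be swapped for the whole module, as Examples 8 and 9 do with `modules=net` and `modules=model`. A minimal sketch of that variant, reusing the names from the example above:

    # Equivalent setup: hand the whole model to delve instead of a layer list.
    stats = SaturationTracker('regression/h{}'.format(h),
                              save_to="plotcsv",
                              modules=model,
                              device=device,
                              stats=["lsat", "lsat_eval"])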
Example 8
    epochs = 10

    net = NET()
    if torch.cuda.is_available():
        net.cuda()

    net.to(device)
    logging_dir = 'net/simpson_h2-{}'.format(2)

    stats = SaturationTracker(savefile=logging_dir,
                              save_to='plot',
                              modules=net,
                              include_conv=False,
                              stats=['lsat'],
                              max_samples=1024,
                              verbose=True,
                              writer_args={
                                  'figsize': [30, 30],
                                  'fontsize': 32
                              },
                              conv_method='mean',
                              device='cpu')

    #net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1'])
    # Keep eps on the same device as the model instead of calling .cuda()
    # unconditionally, which fails on CPU-only machines.
    eps = torch.tensor([1e-10]).to(device)

    def loss_fn(recon_x, x, mu, logvar, eps):
        BCE = F.binary_cross_entropy(recon_x + eps, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu**2 - logvar.exp())
        return (BCE + KLD) / x.size(0)
Example 9
    test_loader = DataLoader(test_data,
                             batch_size=1024,
                             shuffle=False,
                             pin_memory=True)

    # instantiate model
    model = vgg11(num_classes=10).to(device)

    # instantiate optimizer and loss
    optimizer = Adam(params=model.parameters())
    criterion = CrossEntropyLoss().to(device)

    # initialize delve
    tracker = SaturationTracker("experiment",
                                save_to="plotcsv",
                                stats=["lsat"],
                                modules=model,
                                device=device)

    # begin training
    for epoch in range(10):
        # only record saturation for odd-numbered epochs
        if epoch % 2 == 1:
            tracker.resume()
        else:
            tracker.stop()
        model.train()
        for (images, labels) in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            prediction = model(images)
            optimizer.zero_grad(set_to_none=True)
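
The snippet above is cut off mid-step; its original continuation is not shown here. As a sketch only, a typical epoch tail built from the tracker calls used in the earlier examples (the evaluate helper is hypothetical):

            loss = criterion(prediction, labels)
            loss.backward()
            optimizer.step()

        # Log the primary metric, then flush the saturations recorded this epoch.
        test_accuracy = evaluate(model, test_loader)  # hypothetical helper
        tracker.add_scalar("test_accuracy", test_accuracy)
        tracker.add_saturations()

    tracker.close()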