Exemplo n.º 1
0
def test_lstm_saturation_embed_runs():
    """Smoke test: SaturationTracker with the 'embed' stat on an LSTM.

    Feeds two batches of different batch sizes and checks that the
    tracker's saved covariance samples follow the latest batch size.
    """
    save_path = TEMP_DIRNAME
    # Run 2
    timeseries_method = 'last_timestep'

    model = torch.nn.Sequential().to(device)
    lstm = torch.nn.LSTM(10, 88, 2)
    model.add_module('lstm', lstm)

    writer = CSVandPlottingWriter(save_path, fontsize=16)
    saturation = SaturationTracker(save_path, [writer],
                                   model,
                                   stats=['lsat', 'idim', 'embed'],
                                   timeseries_method=timeseries_method,
                                   device=device)

    # Batch of 5 sequences, 3 timesteps, 10 features each.
    seq_batch = torch.randn(5, 3, 10).to(device)
    output, (hn, cn) = model(seq_batch)
    assert saturation.logs['train-covariance-matrix'][
        'lstm'].saved_samples.shape == torch.Size([5, 88])

    # Bug fix: the second batch must also be moved to `device`; the
    # original left it on CPU, which fails when `device` is CUDA.
    seq_batch = torch.randn(8, 3, 10).to(device)
    output, (hn, cn) = model(seq_batch)
    assert saturation.logs['train-covariance-matrix'][
        'lstm'].saved_samples.shape == torch.Size([8, 88])
    saturation.add_saturations()
    saturation.close()
    return True
Exemplo n.º 2
0
def test_dense_saturation_runs():
    """Smoke test: SaturationTracker attaches to a single dense layer."""
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Linear(10, 88)).to(device)

    csv_writer = CSVandPlottingWriter(save_path,
                                      fontsize=16,
                                      primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [csv_writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    # One forward pass is enough to trigger the tracker's hooks.
    batch = torch.randn(5, 10).to(device)
    _ = model(batch)
    return True
Exemplo n.º 3
0
def test_conv_saturation_runs_with_pca():
    """Smoke test: tracker on a conv model that ends in a Conv2DPCALayer.

    Runs one forward pass in train mode and one in eval mode, where the
    PCA layer changes its behavior.
    """
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Conv2d(4, 88, (3, 3)),
                                Conv2DPCALayer(88)).to(device)

    csv_writer = CSVandPlottingWriter(save_path,
                                      fontsize=16,
                                      primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [csv_writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    batch = torch.randn(32, 4, 10, 10).to(device)
    _ = model(batch)
    model.eval()
    _ = model(batch)
    return True
Exemplo n.º 4
0
    def _initialize_tracker(self):
        """Create the saturation tracker and its CSV/plot writer for this run."""
        base_path = self._save_path.replace('.csv', '')
        writer = CSVandPlottingWriter(base_path,
                                      primary_metric='test_accuracy')

        # Interpolation is only wanted when downsampling is enabled.
        strategy = 'nearest' if self.downsampling is not None else None

        self._tracker = CheckLayerSat(
            base_path, [writer],
            self.model,
            ignore_layer_names='convolution',
            stats=['lsat', 'idim'],
            sat_threshold=self.delta,
            verbose=False,
            conv_method=self.conv_method,
            log_interval=1,
            device=self.device_sat,
            reset_covariance=True,
            max_samples=None,
            initial_epoch=self._initial_epoch,
            interpolation_strategy=strategy,
            interpolation_downsampling=self.downsampling)
Exemplo n.º 5
0
def test_conv_saturation_runs_with_pca_injecting_random_directions():
    """Smoke test: injecting random PCA directions must change the output.

    Runs the model before and after
    change_all_pca_layer_thresholds_and_inject_random_directions and
    returns True iff the outputs differ.
    """
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Conv2d(4, 88, (3, 3)),
                                Conv2DPCALayer(88)).to(device)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    _ = SaturationTracker(save_path, [writer],
                          model,
                          stats=['lsat', 'idim'],
                          device=device)

    test_input = torch.randn(32, 4, 10, 10).to(device)
    _ = model(test_input)
    model.eval()
    x = model(test_input)
    change_all_pca_layer_thresholds_and_inject_random_directions(0.99, model)
    y = model(test_input)
    # Bug fix: `x != y` on tensors yields an element-wise bool tensor,
    # not a truth value; use torch.equal to get a proper Python bool.
    return not torch.equal(x, y)
Exemplo n.º 6
0
def test_dense_saturation_runs_with_many_writers():
    """Smoke test: one tracker feeding several writers (CSV/plot, NPY, stdout)."""
    save_path = TEMP_DIRNAME
    model = torch.nn.Sequential(torch.nn.Linear(10, 88)).to(device)

    writers = [
        CSVandPlottingWriter(save_path,
                             fontsize=16,
                             primary_metric='test_accuracy'),
        NPYWriter(save_path),
        PrintWriter(),
    ]
    tracker = SaturationTracker(save_path, writers,
                                model,
                                stats=['lsat', 'idim'],
                                device=device)

    batch = torch.randn(5, 10).to(device)
    _ = model(batch)
    tracker.add_scalar("test_accuracy", 1.0)
    tracker.add_saturations()

    return True
Exemplo n.º 7
0
def test_lstm_saturation_runs():
    """Smoke test: SaturationTracker with timestepwise aggregation on an LSTM."""
    save_path = TEMP_DIRNAME

    # Run 1
    timeseries_method = 'timestepwise'

    model = torch.nn.Sequential().to(device)
    lstm = torch.nn.LSTM(10, 88, 2)
    lstm.name = 'lstm2'
    model.add_module('lstm', lstm)

    writer = CSVandPlottingWriter(save_path,
                                  fontsize=16,
                                  primary_metric='test_accuracy')
    saturation = SaturationTracker(save_path, [writer],
                                   model,
                                   stats=['lsat', 'idim'],
                                   timeseries_method=timeseries_method,
                                   device=device)

    # Batch of 5 sequences, 3 timesteps, 10 features each.
    seq_batch = torch.randn(5, 3, 10).to(device)
    output, (hn, cn) = model(seq_batch)
    saturation.close()
    # Consistency fix: every sibling test returns True on success.
    return True
Exemplo n.º 8
0
    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_prallel=False,
                 conv_method='channelwise',
                 thresh=.99):
        """Set up a training experiment with saturation tracking.

        Args:
            model: torch module to train; must expose a ``name`` attribute.
            train_loader / test_loader: data loaders; ``train_loader`` must
                expose a ``name`` attribute (used in the log directory).
            epochs, batch_size, run_id: experiment identifiers, also encoded
                in the CSV log filename.
            logs_dir: root directory for per-model/per-dataset log folders.
            device: torch device string for training.
            saturation_device: device for saturation statistics; defaults to
                ``device`` when None.
            optimizer: one of 'adam', 'bad_lr_adam', 'SGD', 'LRS', 'radam'.
            plot, compute_top_k: feature flags stored for later use.
            data_prallel: wrap the model in nn.DataParallel on cuda:0/cuda:1.
            conv_method: pooling strategy for conv-layer saturation.
            thresh: saturation recording threshold.

        Raises:
            ValueError: if ``optimizer`` is not one of the known names.
        """
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()
        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == 'bad_lr_adam':
            # Deliberately mis-tuned optimizer for degradation experiments.
            print('Using adam with too large learning rate')
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.5,
                                       momentum=0.9)
        elif optimizer == "LRS":
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.01,
                                       momentum=0.9)
            self.lr_scheduler = optim.lr_scheduler.StepLR(self.optimizer, 5)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        # exist_ok avoids a race when several runs start concurrently.
        os.makedirs(save_dir, exist_ok=True)

        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_t{int(thresh*1000)}_id{run_id}.csv'
        )
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Experiment logs for the exact same experiment with an '
                    'identical run_id were detected, training will be '
                    'skipped; consider using another run_id'
                )
        self.parallel = data_prallel
        if data_prallel:
            self.model = nn.DataParallel(self.model, ['cuda:0', 'cuda:1'])
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        # Bug fixes: the tracker previously hard-coded sat_threshold=.99,
        # silently ignoring the `thresh` parameter that is announced above
        # and encoded in the filename; the writer is now also passed as a
        # list, consistent with every other CheckLayerSat call site.
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''),
                                   [writer],
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None)
Exemplo n.º 9
0
    def __init__(self,
                 model,
                 train_loader,
                 test_loader,
                 epochs=200,
                 batch_size=60,
                 run_id=0,
                 logs_dir='logs',
                 device='cpu',
                 saturation_device=None,
                 optimizer='None',
                 plot=True,
                 compute_top_k=False,
                 data_prallel=False,
                 conv_method='channelwise',
                 thresh=.99,
                 half_precision=False,
                 downsampling=None):
        """Set up a resumable training experiment with saturation tracking.

        Args:
            model: torch module to train; must expose a ``name`` attribute.
            train_loader / test_loader: data loaders; ``train_loader`` must
                expose a ``name`` attribute (used in the log directory).
            epochs, batch_size, run_id: experiment identifiers, also encoded
                in the CSV log filename.
            logs_dir: root directory for per-model/per-dataset log folders.
            device: torch device string for training.
            saturation_device: device for saturation statistics; defaults to
                ``device`` when None.
            optimizer: one of 'adam', 'adam_lr', 'adam_lr2', 'SGD', 'LRS',
                'radam'.
            plot, compute_top_k: feature flags stored for later use.
            data_prallel: wrap the model in nn.DataParallel.
            conv_method: pooling strategy for conv-layer saturation.
            thresh: saturation recording threshold.
            half_precision: run the model in fp16.
            downsampling: optional feature-map downsampling; when set, the
                tracker interpolates with the 'nearest' strategy.

        Raises:
            ValueError: if ``optimizer`` is not one of the known names.
        """
        self.saturation_device = device if saturation_device is None else saturation_device
        self.device = device
        self.model = model
        self.epochs = epochs
        self.plot = plot
        self.compute_top_k = compute_top_k

        if 'cuda' in device:
            cudnn.benchmark = True

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.criterion = nn.CrossEntropyLoss()
        print('Checking for optimizer for {}'.format(optimizer))
        if optimizer == "adam":
            print('Using adam')
            self.optimizer = optim.Adam(model.parameters())
        elif optimizer == "adam_lr":
            print("Using adam with higher learning rate")
            self.optimizer = optim.Adam(model.parameters(), lr=0.01)
        elif optimizer == 'adam_lr2':
            # NOTE(review): despite the print, lr=0.0001 is *lower* than the
            # Adam default (0.001) — message and value look inconsistent,
            # preserved as-is to keep experiment configs reproducible.
            print('Using adam with to large learning rate')
            self.optimizer = optim.Adam(model.parameters(), lr=0.0001)
        elif optimizer == "SGD":
            print('Using SGD')
            self.optimizer = optim.SGD(model.parameters(),
                                       momentum=0.9,
                                       weight_decay=5e-4)
        elif optimizer == "LRS":
            print('Using LRS')
            self.optimizer = optim.SGD(model.parameters(),
                                       lr=0.1,
                                       momentum=0.9,
                                       weight_decay=5e-4)
            # Step the LR three times over the whole run.
            self.lr_scheduler = optim.lr_scheduler.StepLR(
                self.optimizer, self.epochs // 3)
        elif optimizer == "radam":
            print('Using radam')
            self.optimizer = RAdam(model.parameters())
        else:
            raise ValueError('Unknown optimizer {}'.format(optimizer))
        self.opt_name = optimizer
        save_dir = os.path.join(logs_dir, model.name, train_loader.name)
        # exist_ok avoids a race when several runs start concurrently.
        os.makedirs(save_dir, exist_ok=True)

        self.savepath = os.path.join(
            save_dir,
            f'{model.name}_bs{batch_size}_e{epochs}_dspl{downsampling}_t{int(thresh*1000)}_id{run_id}.csv'
        )
        self.experiment_done = False
        if os.path.exists(self.savepath):
            trained_epochs = len(pd.read_csv(self.savepath, sep=';'))

            if trained_epochs >= epochs:
                self.experiment_done = True
                print(
                    'Experiment logs for the exact same experiment with an '
                    'identical run_id were detected, training will be '
                    'skipped; consider using another run_id'
                )
        # Bug fix: self.parallel was only assigned on the fresh-start branch,
        # leaving the attribute undefined when a run is resumed.
        self.parallel = data_prallel
        ckpt_path = self.savepath.replace('.csv', '.pt')
        if os.path.exists(ckpt_path):
            # Resume: restore model and optimizer state from the checkpoint.
            # Load the checkpoint once instead of three separate torch.load
            # calls on the same file.
            checkpoint = torch.load(ckpt_path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
            if half_precision:
                self.model = self.model.half()
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.start_epoch = checkpoint['epoch'] + 1
            initial_epoch = self._infer_initial_epoch(self.savepath)
            print('Resuming existing run, starting at epoch', self.start_epoch,
                  'from', ckpt_path)
        else:
            if half_precision:
                self.model = self.model.half()
            self.start_epoch = 0
            initial_epoch = 0
            if data_prallel:
                self.model = nn.DataParallel(self.model)
            self.model = self.model.to(self.device)
        writer = CSVandPlottingWriter(self.savepath.replace('.csv', ''),
                                      fontsize=16,
                                      primary_metric='test_accuracy')
        writer2 = NPYWriter(self.savepath.replace('.csv', ''))
        self.pooling_strat = conv_method
        print('Setting saturation recording threshold to', thresh)
        self.half = half_precision

        # Bug fix: sat_threshold previously hard-coded .99, silently
        # ignoring the `thresh` parameter that is announced above and
        # encoded in the filename.
        # NOTE(review): interpolation_downsampling is hard-coded to 4 and
        # ignores the `downsampling` parameter (sibling code passes it
        # through) — preserved as-is; confirm intent.
        self.stats = CheckLayerSat(self.savepath.replace('.csv', ''), [writer],
                                   model,
                                   ignore_layer_names='convolution',
                                   stats=['lsat', 'idim'],
                                   sat_threshold=thresh,
                                   verbose=False,
                                   conv_method=conv_method,
                                   log_interval=1,
                                   device=self.saturation_device,
                                   reset_covariance=True,
                                   max_samples=None,
                                   initial_epoch=initial_epoch,
                                   interpolation_strategy='nearest'
                                   if downsampling is not None else None,
                                   interpolation_downsampling=4)