Example #1
    def predict(self, name_extra=""):
        import helpers.output_writers as ow

        model = self.model
        for name in self.config.datasets:
            dataset_config = AttrDefault(lambda: None,
                                         self.config.datasets[name])
            if dataset_config.predicting:
                sid, out = self.do_predict(name, dataset_config, model)
                for owriter_name in dataset_config.writers:
                    owcnfg = dataset_config.writers[owriter_name]
                    ow.__dict__[owcnfg['name']](sid, out, self,
                                                name + name_extra,
                                                owriter_name, **owcnfg['args'])

        if self.use_swa:
            # repeat predictions with the stochastic weight averaging (SWA) model
            model = self.swa_model
            for name in self.config.datasets:
                dataset_config = AttrDefault(lambda: None,
                                             self.config.datasets[name])
                if dataset_config.predicting:
                    sid, out = self.do_predict(name, dataset_config, model)
                    for owriter_name in dataset_config.writers:
                        owcnfg = dataset_config.writers[owriter_name]
                        ow.__dict__[owcnfg['name']](sid, out, self, name,
                                                    owriter_name + "_swa",
                                                    **owcnfg['args'])
Example #2
def fill_limits(month_data, day_data):
    limits = AttrDefault(dict)
    limits.months.timestamp.maximum = max(
        month_data, key=lambda x: x['timestamp'])['timestamp'].split()[0]
    limits.months.timestamp.minimum = min(
        month_data, key=lambda x: x['timestamp'])['timestamp'].split()[0]
    limits.months.consumption.maximum = max(
        month_data, key=lambda x: x['consumption'])['consumption']
    limits.months.consumption.minimum = min(
        month_data, key=lambda x: x['consumption'])['consumption']
    limits.months.temperature.maximum = max(
        month_data, key=lambda x: x['temperature'])['temperature']
    limits.months.temperature.minimum = min(
        month_data, key=lambda x: x['temperature'])['temperature']

    limits.days.timestamp.maximum = max(
        day_data, key=lambda x: x['timestamp'])['timestamp'].split()[0]
    limits.days.timestamp.minimum = min(
        day_data, key=lambda x: x['timestamp'])['timestamp'].split()[0]
    limits.days.consumption.maximum = max(
        day_data, key=lambda x: x['consumption'])['consumption']
    limits.days.consumption.minimum = min(
        day_data, key=lambda x: x['consumption'])['consumption']
    limits.days.temperature.maximum = max(
        day_data, key=lambda x: x['temperature'])['temperature']
    limits.days.temperature.minimum = min(
        day_data, key=lambda x: x['temperature'])['temperature']
    return limits
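
The nested writes above rely on AttrDefault creating intermediate mappings on first access. A minimal standalone sketch of that behaviour (assuming the AttrDefault class from the attrdict package, which these examples appear to use; the values are made up):

from attrdict import AttrDefault

# each missing attribute is created via the default factory (here: dict),
# so a whole nested path can be assigned in one statement
limits = AttrDefault(dict)
limits.months.timestamp.maximum = "2021-01-31"
print(limits.months.timestamp.maximum)  # 2021-01-31
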
Example #3
def test_kfac_hessian():
    A, model = create_toy_model()
    data = A.t()
    data = data.repeat(7, 1)
    n = float(len(data))

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    def save_activations(layer, a, _):
        activations[layer] = a

    def compute_hessian(layer, _, B):
        A = activations[layer]
        hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
        hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

    for x in data:
        with autograd_lib.module_hook(save_activations):
            y = model(x)
            o = y.shape[1]
            loss = torch.sum(y * y) / 2

        with autograd_lib.module_hook(compute_hessian):
            autograd_lib.backprop_identity(y)

    hess0 = hess[model.layers[0]]
    result = u.kron(hess0.BB / n, hess0.AA / o)

    # check result against autograd
    loss = u.least_squares(model(data), aggregation='sum')
    hess0 = u.hessian(loss, model.layers[0].weight).reshape(4, 4)
    u.check_equal(hess0, result)
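
The defaultdict(lambda: AttrDefault(float)) idiom above gives per-layer stat holders whose fields start at 0.0, so running sums like .AA and .BB can be accumulated with += on first touch. A minimal sketch of just that pattern, independent of autograd_lib and the u helpers (shapes are arbitrary):

from collections import defaultdict

import torch
from attrdict import AttrDefault

stats = defaultdict(lambda: AttrDefault(float))  # every missing field defaults to 0.0

for _ in range(3):
    A = torch.randn(4, 2)
    stats["layer0"].AA += torch.einsum("ni,nj->ij", A, A)  # 0.0 + tensor on the first pass

print(stats["layer0"].AA.shape)  # torch.Size([2, 2])
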
Example #4
def test_kfac_jacobian_mnist():
    u.seed_random(1)

    data_width = 3
    d = [data_width**2, 8, 10]
    model: u.SimpleMLP = u.SimpleMLP(d, nonlin=False)
    autograd_lib.register(model)

    batch_size = 4
    stats_steps = 2
    n = batch_size * stats_steps

    dataset = u.TinyMNIST(dataset_size=n,
                          data_width=data_width,
                          original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    jacobians = defaultdict(lambda: AttrDefault(float))
    total_data = []

    # sum up statistics over n examples
    for train_step in range(stats_steps):
        data, targets = next(train_iter)
        total_data.append(data)

        activations = {}

        def save_activations(layer, A, _):
            activations[layer] = A
            jacobians[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_jacobian(layer, _, B):
            A = activations[layer]
            jacobians[layer].BB += torch.einsum("ni,nj->ij", B, B)
            jacobians[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_jacobian):
            autograd_lib.backward_jacobian(output)

    for layer in model.layers:
        jacobian0 = jacobians[layer]
        jacobian_full = torch.einsum('kl,ij->kilj', jacobian0.BB / n,
                                     jacobian0.AA / n)
        jacobian_diag = jacobian0.diag / n

        J = u.jacobian(model(torch.cat(total_data)), layer.weight)
        J_autograd = torch.einsum('noij,nokl->ijkl', J, J) / n
        u.check_equal(jacobian_full, J_autograd)

        u.check_equal(jacobian_diag, torch.einsum('ikik->ik', J_autograd))
Example #5
    def evaluate(self):
        model = self.model

        # TODO: compute predictions in this function (similar to train, test...)
        # this would allow using "evaluations" in addition to "total_evaluations"
        # keep track inside a metrics_meter (so tp, fp, ... do not need to be recomputed in the eval function)

        for dataset_name in self.config.datasets:
            dataset_config = AttrDefault(lambda: None,
                                         self.config.datasets[dataset_name])
            if dataset_config.evaluating:
                print("evaluate on ", dataset_name)
                # TODO: do not allow "evaluations" because this is not called after every batch
                data_loader = self.data_loaders[dataset_name]
                eval_metrics = {}
                for ef in dataset_config._mapping.get("total_evaluations", []):
                    ev_func = get_total_evaluation(ef["name"])
                    eval_metrics = {
                        **eval_metrics,
                        **ev_func(None,
                                  model=model,
                                  data_loader=data_loader,
                                  config=self.config,
                                  current_dataset_config=dataset_config,
                                  eval_args=ef.get("eval_args", {}))
                    }

                    # logger.info('total metrics:  {}'.format(str(eval_metrics)))
                shared_globals.console.info("evaluation " + dataset_name +
                                            ":\n" + str(eval_metrics))
Example #6
def test_kfac_fisher_mnist():
    u.seed_random(1)

    data_width = 3
    d = [data_width**2, 8, 10]
    model: u.SimpleMLP = u.SimpleMLP(d, nonlin=False)
    autograd_lib.register(model)

    batch_size = 4
    stats_steps = 2
    n = batch_size * stats_steps

    dataset = u.TinyMNIST(dataset_size=n,
                          data_width=data_width,
                          original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    fishers = defaultdict(lambda: AttrDefault(float))
    total_data = []

    # sum up statistics over n examples
    for train_step in range(stats_steps):
        data, targets = next(train_iter)
        total_data.append(data)

        activations = {}

        def save_activations(layer, A, _):
            activations[layer] = A
            fishers[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets) * len(
                data)  # remove data normalization

        def compute_fisher(layer, _, B):
            A = activations[layer]
            fishers[layer].BB += torch.einsum("ni,nj->ij", B, B)
            fishers[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)

        with autograd_lib.module_hook(compute_fisher):
            autograd_lib.backward_jacobian(output)

    for layer in model.layers:
        fisher0 = fishers[layer]
        fisher_full = torch.einsum('kl,ij->kilj', fisher0.BB / n,
                                   fisher0.AA / n)
        fisher_diag = fisher0.diag / n

        u.check_equal(torch.einsum('ikik->ik', fisher_full), fisher_diag)
Example #7
def _test_kfac_hessian_xent_mnist():
    u.seed_random(1)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 10]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=False, bias=True)
    autograd_lib.register(model)
    dataset = u.TinyMNIST(dataset_size=batch_size,
                          data_width=data_width,
                          original_targets=True)
    trainloader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    train_iter = iter(trainloader)

    loss_fn = torch.nn.CrossEntropyLoss()

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))
    for train_step in range(train_steps):
        data, targets = next(train_iter)

        activations = {}

        def save_activations(layer, a, _):
            activations[layer] = a

        with autograd_lib.module_hook(save_activations):
            output = model(data)
            loss = loss_fn(output, targets)

        def compute_hess(layer, _, B):
            A = activations[layer]
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(output,
                                          loss='CrossEntropy',
                                          retain_graph=True)

        hess_factored = hess[model.layers[0]]
        hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n,
                             hess_factored.AA / o)  # hess for sum loss
        hess0 /= n  # hess for mean loss

        # compute Hessian through autograd
        H_autograd = u.hessian(loss, model.layers[0].weight)
        rel_error = torch.norm(
            (hess0 - H_autograd).flatten()) / torch.norm(H_autograd.flatten())
        assert rel_error < 0.01  # 0.0057
Example #8
def test_full_hessian_xent_kfac2():
    """Test with uneven layers."""
    u.seed_random(1)
    torch.set_default_dtype(torch.float64)

    batch_size = 1
    d = [3, 2]
    o = d[-1]
    n = batch_size
    train_steps = 1

    model: u.SimpleModel = u.SimpleFullyConnected2(d, nonlin=True, bias=False)
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.2, 0.1]]))
    targets = torch.tensor([0])

    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hess = defaultdict(lambda: AttrDefault(float))

    for i in range(n):

        def save_activations(layer, A, _):
            activations[layer] = A
            hess[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            data_batch = data[i:i + 1]
            targets_batch = targets[i:i + 1]
            Y = model(data_batch)
            o = Y.shape[1]
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            hess[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy')

    # expand
    hess_factored = hess[model.layers[0]]
    hess0 = torch.einsum('kl,ij->kilj', hess_factored.BB / n,
                         hess_factored.AA / o)  # hess for sum loss
    hess0 /= n  # hess for mean loss

    # check against autograd
    # 0.1459
    Y = model(data)
    loss = loss_fn(Y, targets)
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    u.check_equal(hess_autograd, hess0)
Example #9
    def __call__(self, cli_cmd):
        cli_list = shlex.split(cli_cmd)
        command = cli_list[0]

        # parse command file
        cmd_struct = self.parser_cmd(command)

        '''
        get subcommand
        in order for argparse to work correctly,
        a subcommand must appear in the CLI command,
        so add the subcommand into cli_list
        '''
        try:
            if cli_list[1] != 'help' and cli_list[1] not in cmd_struct['commands'].keys():
                # the subcommand did not appear in the CLI command,
                # so insert 'default' as the subcommand (second element of cli_list)
                cli_list.insert(1, 'default')
        except IndexError:
            # IndexError means the CLI command contains only the main command
            cli_list.append('default')
        finally:
            subcommand = cli_list[1]
            cmd_struct['commands'].setdefault(subcommand, {})

        cmd_property = dict(cmd_struct['commands'].get('_property', {}))
        cmd_property.update(cmd_struct['commands'][subcommand].get('property', {}))
        # set action to subcommand
        cmd_property.setdefault('action', subcommand)

        # get command params
        cmd_params = dict(cmd_struct['commands'][subcommand].get('params', {}))
        parser = TippyArgumentParser(
            prog=cli_list[0],
            description=cmd_struct.get('description')
        )
        subparsers = parser.add_subparsers()
        self.generate_cli_parser(subparsers, subcommand, cmd_params)

        if not cmd_params.get('_lock', False):
            cmd_params.update(vars(parser.parse_args(cli_list[1:])))

        cmd = AttrDefault(str, {})
        cmd.description = cmd_struct.get('description')
        cmd.property = cmd_property
        cmd.help = cmd_struct.get('help', command)
        cmd.command = command
        cmd.subcommand = subcommand
        cmd.params = cmd_params
        return cmd
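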
Example #10
    def init_loaders(self):
        # maybe lazy load for predicting only runs
        for name in self.config.datasets:
            dataset_config = AttrDefault(lambda: None,
                                         self.config.datasets[name])
            if (self.config['predict_only_mode']
                    and not dataset_config.predicting):
                continue
            # ds = self.run.get_command_function(dataset_config.dataset)()
            ds = self.dataset_manager.get_dataset(dataset_config)

            self.datasets[name] = ds
            shared_globals.logger.info("Initialized Dataset  `" + name +
                                       "` with {} Samples ".format(len(ds)))
            if dataset_config.batch_config.get(
                    "batch_sampler") == "stratified":
                shared_globals.logger.info(
                    "Initializing StratifiedBatchSampler for " + name)
                batch_sampler = StratifiedBatchSampler(
                    ds, dataset_config.batch_config.batch_size,
                    self.config.epochs)
            elif dataset_config.batch_config.get(
                    "batch_sampler") == "sequential":
                shared_globals.logger.info(
                    "Initializing Sequential Sampler for " + name)
                sampler = SequentialSampler(ds)
                batch_sampler = BatchSampler(
                    sampler, dataset_config.batch_config.batch_size, False)
            else:
                if dataset_config.testing or dataset_config.predicting:
                    shared_globals.logger.info(
                        "Initializing Sequential Sampler for " + name)
                    sampler = SequentialSampler(ds)
                else:
                    shared_globals.logger.info(
                        "Initializing RandomSampler for " + name)
                    sampler = RandomSampler(ds)
                batch_sampler = BatchSampler(
                    sampler, dataset_config.batch_config.batch_size, True)
            loader = torch.utils.data.DataLoader(
                ds,
                # batch_size=batch_size,
                batch_sampler=batch_sampler,
                # shuffle=True,
                num_workers=dataset_config.num_of_workers,
                pin_memory=True,
                # drop_last=True,
                worker_init_fn=worker_init_fn,
                timeout=60)
            self.data_loaders[name] = loader
Example #11
def test_hessian_kfac():
    model: u.SimpleMLP = u.SimpleMLP([2, 2], nonlin=True, bias=True)
    model.layers[0].weight.data.copy_(torch.eye(2))
    autograd_lib.register(model)
    loss_fn = torch.nn.CrossEntropyLoss()

    data = u.to_logits(torch.tensor([[0.7, 0.3]]))
    targets = torch.tensor([0])

    data = data.repeat([3, 1])
    targets = targets.repeat([3])
    n = len(data)

    activations = {}
    hessians = defaultdict(lambda: AttrDefault(float))

    for i in range(n):
        def save_activations(layer, A, _):
            activations[layer] = A
            hessians[layer].AA += torch.einsum("ni,nj->ij", A, A)

        with autograd_lib.module_hook(save_activations):
            data_batch = data[i: i+1]
            targets_batch = targets[i: i+1]
            Y = model(data_batch)
            loss = loss_fn(Y, targets_batch)

        def compute_hess(layer, _, B):
            hessians[layer].BB += torch.einsum("ni,nj->ij", B, B)

        with autograd_lib.module_hook(compute_hess):
            autograd_lib.backward_hessian(Y, loss='CrossEntropy', retain_graph=True)

    # check diagonal entries against autograd
    hess_autograd = u.hessian(loss, model.layers[0].weight)
    hess0_factored = hessians[model.layers[0]]

    diag_autograd = torch.einsum('lili->li', hess_autograd)
    diag_kfac = torch.einsum('ll,ii->li', hess0_factored.BB / n, hess0_factored.AA / n)
    u.check_close(diag_autograd,  diag_kfac)

    # check all entries against autograd
    hess0 = torch.einsum('kl,ij->kilj', hess0_factored.BB / n, hess0_factored.AA / n)
    u.check_close(hess_autograd, hess0)
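
The einsum('kl,ij->kilj', ...) expansion used in these KFAC tests is just the Kronecker product of the two factors written out index by index. A quick standalone sanity check with plain PyTorch (no project helpers such as u.kron involved):

import torch

B = torch.randn(3, 3)  # stand-in for a BB factor
A = torch.randn(2, 2)  # stand-in for an AA factor

# kilj reshaped to (K*I, L*J) is exactly the Kronecker product kron(B, A)
expanded = torch.einsum('kl,ij->kilj', B, A).reshape(3 * 2, 3 * 2)
assert torch.allclose(expanded, torch.kron(B, A))
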
Example #12
    def compute_layer_stats(layer):
        refreeze = False
        if hasattr(layer, 'frozen') and layer.frozen:
            u.unfreeze(layer)
            refreeze = True

        s = AttrDefault(str, {})
        n = args.stats_batch_size
        param = u.get_param(layer)
        _d = len(param.flatten())  # dimensionality of parameters
        layer_idx = model.layers.index(layer)
        # TODO: get layer type, include it in name
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        _loss, _output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        s.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2
        s.grad_fro = g.flatten().norm()
        s.param_fro = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        s.a_sparsity = neg_activations.float() / (
            pos_activations + neg_activations)  # 1 sparsity means all 0's
        activation_size = len(layer.data_output.flatten())
        s.a_magnitude = torch.sum(layer.data_output) / activation_size

        _output = backprop_output()
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)  # batch output Jacobian
        H = J.t() @ J / n

        s.hessian_l2 = u.l2_norm(H)
        s.jacobian_l2 = u.l2_norm(J)
        J1 = J.sum(dim=0) / n  # single output Jacobian
        s.J1_l2 = J1.norm()

        # newton decrement
        def loss_direction(direction, eps):
            """loss improvement if we take step eps in direction dir"""
            return u.to_python_scalar(eps * (direction @ g.t()) - 0.5 *
                                      eps**2 * direction @ H @ direction.t())

        s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)

        # TODO: gradient diversity is stuck at 1
        # TODO: newton/gradient angle
        # TODO: newton step magnitude
        s.grad_curvature = u.to_python_scalar(
            g @ H @ g.t())  # curvature in direction of g
        s.step_openai = u.to_python_scalar(
            s.grad_fro**2 / s.grad_curvature) if s.grad_curvature else 999

        s.regret_gradient = loss_direction(g, s.step_openai)

        if refreeze:
            u.freeze(layer)
        return s
Example #13
    def __init__(self, config):
        self.config = AttrDefault(lambda: None, config)
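
Wrapping the raw config dict in AttrDefault(lambda: None, config) is what lets the other examples probe optional keys such as dataset_config.predicting without guarding every lookup. A minimal sketch with hypothetical keys (again assuming AttrDefault from the attrdict package):

from attrdict import AttrDefault

config = AttrDefault(lambda: None, {"batch_size": 32})

print(config.batch_size)  # 32
print(config.predicting)  # None -- missing keys fall back to the factory instead of raising
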
Example #14
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size',
                        type=int,
                        default=64,
                        metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb',
                        type=int,
                        default=1,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check',
                        type=int,
                        default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir',
                        type=str,
                        default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin',
                        type=int,
                        default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias',
                        type=int,
                        default=1,
                        help="whether to add bias between layers")

    parser.add_argument('--layer',
                        type=int,
                        default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument(
        '--hess_samples',
        type=int,
        default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian'
    )
    parser.add_argument('--hess_kfac',
                        type=int,
                        default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho',
                        type=int,
                        default=0,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats',
                        type=int,
                        default=1,
                        help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps',
                        type=int,
                        default=5,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps',
                        type=int,
                        default=1000000,
                        help="total number of curvature stats collections")

    parser.add_argument('--full_batch',
                        type=int,
                        default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=1e-5)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=1)
    parser.add_argument('--lmb', type=float, default=1e-3)
    parser.add_argument('--uniform',
                        type=int,
                        default=0,
                        help="all layers same size")
    parser.add_argument('--redundancy',
                        type=int,
                        default=0,
                        help="duplicate all layers this many times")
    args = parser.parse_args()

    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    d1 = 28 * 28
    if args.uniform:
        d = [784, 784, 784, 784, 784, 784, 10]
    else:
        d = [784, 2500, 2000, 1500, 1000, 500, 10]
    o = 10
    n = args.stats_batch_size
    if args.redundancy:
        model = u.RedundantFullyConnected2(d,
                                           nonlin=args.nonlin,
                                           bias=args.bias,
                                           dropout=args.dropout,
                                           redundancy=args.redundancy)
    else:
        model = u.SimpleFullyConnected2(d,
                                        nonlin=args.nonlin,
                                        bias=args.bias,
                                        dropout=args.dropout)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='train_ciresan', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['redundancy'] = args.redundancy
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          original_targets=True,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=False,
                                              drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        epoch = gl.token_count // 60000
        print(gl.token_count)
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        # compute validation loss
        model.eval()
        if args.swa:
            with u.timeit('swa'):
                base_opt = torch.optim.SGD(model.parameters(),
                                           lr=args.lr,
                                           momentum=args.momentum)
                opt = torchcontrib.optim.SWA(base_opt,
                                             swa_start=0,
                                             swa_freq=1,
                                             swa_lr=args.lr)
                for _ in range(100):
                    optimizer.zero_grad()
                    data, targets = next(train_iter)
                    model.zero_grad()
                    output = model(data)
                    loss = loss_fn(output, targets)
                    loss.backward()
                    opt.step()
                opt.swap_swa_sgd()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_loader,
                                              f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, stats_loader,
                                                  f'train (epoch {epoch})')

        # save log
        metrics = {
            'epoch': epoch,
            'val_accuracy': val_accuracy,
            'val_loss': val_loss,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'lr': optimizer.param_groups[0]['lr'],
            'momentum': optimizer.param_groups[0].get('momentum', 0)
        }
        u.log_scalars(metrics)

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        model.skip_forward_hooks = False
        model.skip_backward_hooks = False

        # get gradient values
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.clear_backprops(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        u.log_scalar(loss=loss.item())

        # get Hessian values
        hessian_activations = []
        hessian_backprops = []
        hessians = []  # list of Hessians in Kronecker form

        model.skip_forward_hooks = True
        for (i, layer) in enumerate(model.layers):
            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = (B_t, A_t)
            #    g = G.sum(dim=0, keepdim=True) / n  # average gradient
            g = u.kron_sum(G) / n
            assert g.shape == (1, d[i] * d[i + 1])

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                # efisher = u.kron_cov(G)  # G.t() @ G / n
                sigma = u.kron_sigma(G, g)  #  efisher - g.t() @ g
                s.sigma_l2 = u.kron_sym_l2_norm(sigma)
                s.sigma_erank = u.kron_trace(
                    sigma) / s.sigma_l2  # torch.trace(sigma)/s.sigma_l2

            #############################
            # Hessian stats
            #############################

            # this is a pair of left/right Kronecker factors
            H = hessians[i]

            with u.timeit(f"invH-{i}"):
                invH = u.kron_inverse(H)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.kron_sym_l2_norm(H)
                s.iH_l2 = u.kron_sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = u.kron_fro_norm(H)
                s.invH_fro = u.kron_fro_norm(invH)
                s.grad_fro = u.kron_fro_norm(g)  # g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.kron_nan_check(H)

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.kron_pinv(H)

            def kron_curv_direction(dd: torch.Tensor):
                """Curvature in direction dd, using factored form"""
                # dd @ H @ dd.t(), computed by kron_quadratic_form(H, dd)
                return u.to_python_scalar(
                    u.kron_quadratic_form(H, dd) / (dd.flatten().norm()**2))

            def kron_loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""

                # kron_matmul(dd, g) = dd @ g.t()
                return u.to_python_scalar(eps * (u.kron_matmul(dd, g)) -
                                          0.5 * eps**2 *
                                          u.kron_quadratic_form(H, dd))

            with u.timeit(f'curv-{i}'):
                s.grad_curv = kron_curv_direction(g)
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / u.kron_trace(H)

                s.regret_gradient = kron_loss_direction(g, s.step_openai)

            with u.timeit(f"batch-{i}"):
                # factored form of torch.trace(H @ sigma) / (g @ H @ g.t())
                s.batch_openai = u.kron_trace_matmul(
                    H, sigma) / u.kron_quadratic_form(H, g)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # torch.trace(H)
                s.H_erank = u.kron_trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        model.train()
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            optimizer.step()
            if args.weight_decay:
                for group in optimizer.param_groups:
                    for param in group['params']:
                        param.data.mul_(1 - args.weight_decay)

            gl.token_count += data.shape[0]

    gl.event_writer.close()
Example #15
def main():

    u.install_pdb_handler()
    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {logdir}")

    loss_type = 'CrossEntropy'

    d1 = args.data_width ** 2
    args.stats_batch_size = min(args.stats_batch_size, args.dataset_size)
    args.train_batch_size = min(args.train_batch_size, args.dataset_size)
    n = args.stats_batch_size
    o = 10
    d = [d1, 60, 60, 60, o]
    # dataset_size = args.dataset_size

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin, last_layer_linear=True)
    model = model.to(gl.device)
    u.mark_expensive(model.layers[0])    # to stop grad1/hess calculations on this layer
    print(model)

    try:
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name, dir='/tmp/wandb.runs')
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
    #  optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, dataset_size=args.dataset_size, loss_type=loss_type)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)
    stats_data, stats_targets = next(stats_iter)

    test_dataset = u.TinyMNIST(data_width=args.data_width, train=False, dataset_size=args.dataset_size, loss_type=loss_type)
    test_batch_size = min(args.dataset_size, 1000)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    else:   # loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.reset_global_step()
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        with u.timeit("validate"):
            if loss_type == 'CrossEntropy':
                val_accuracy, val_loss = validate(model, test_loader, f'test (stats_step {step})')
                # train_accuracy, train_loss = validate(model, train_loader, f'train (stats_step {step})')

                metrics = {'stats_step': step, 'val_accuracy': val_accuracy, 'val_loss': val_loss}
                u.log_scalars(metrics)

        data, targets = stats_data, stats_targets

        if not args.skip_stats:
            # Capture Hessian and gradient stats
            autograd_lib.enable_hooks()
            autograd_lib.clear_backprops(model)
            autograd_lib.clear_hess_backprops(model)
            with u.timeit("backprop_g"):
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward(retain_graph=True)
            with u.timeit("backprop_H"):
                autograd_lib.backprop_hess(output, hess_type=loss_type)
            autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

            with u.timeit("compute_grad1"):
                autograd_lib.compute_grad1(model)
            with u.timeit("compute_hess"):
                autograd_lib.compute_hess(model)

            for (i, layer) in enumerate(model.layers):

                if hasattr(layer, 'expensive'):
                    continue

                param_names = {layer.weight: "weight", layer.bias: "bias"}
                for param in [layer.weight, layer.bias]:
                    # input/output layers are unreasonably expensive if not using Kronecker factoring
                    if d[i]*d[i+1] > 8000:
                        print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                        continue

                    s = AttrDefault(str, {})  # dictionary-like object for layer stats

                    #############################
                    # Gradient stats
                    #############################
                    A_t = layer.activations
                    B_t = layer.backprops_list[0] * n
                    s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
                    s.mean_activation = torch.mean(A_t)
                    s.mean_backprop = torch.mean(B_t)

                    # empirical Fisher
                    G = param.grad1.reshape((n, -1))
                    g = G.mean(dim=0, keepdim=True)

                    u.nan_check(G)
                    with u.timeit(f'sigma-{i}'):
                        efisher = G.t() @ G / n
                        sigma = efisher - g.t() @ g
                        # sigma_spectrum =
                        s.sigma_l2 = u.sym_l2_norm(sigma)
                        s.sigma_erank = torch.trace(sigma)/s.sigma_l2

                    H = param.hess
                    lambda_regularizer = args.lmb * torch.eye(H.shape[0]).to(gl.device)
                    u.nan_check(H)

                    with u.timeit(f"invH-{i}"):
                        invH = torch.cholesky_inverse(H+lambda_regularizer)

                    with u.timeit(f"H_l2-{i}"):
                        s.H_l2 = u.sym_l2_norm(H)
                        s.iH_l2 = u.sym_l2_norm(invH)

                    with u.timeit(f"norms-{i}"):
                        s.H_fro = H.flatten().norm()
                        s.iH_fro = invH.flatten().norm()
                        s.grad_fro = g.flatten().norm()
                        s.param_fro = param.data.flatten().norm()

                    def loss_direction(dd: torch.Tensor, eps):
                        """loss improvement if we take step eps in direction dd"""
                        return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

                    def curv_direction(dd: torch.Tensor):
                        """Curvature in direction dd"""
                        return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

                    with u.timeit(f"pinvH-{i}"):
                        pinvH = u.pinv(H)

                    with u.timeit(f'curv-{i}'):
                        s.grad_curv = curv_direction(g)  # curvature (eigenvalue) in direction g
                        ndir = g @ pinvH  # newton direction
                        s.newton_curv = curv_direction(ndir)
                        setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                        s.step_openai = 1 / s.grad_curv if s.grad_curv else 1234567
                        s.step_div_inf = 2 / s.H_l2  # divergent step size for batch_size=infinity
                        s.step_div_1 = torch.tensor(2) / torch.trace(H)   # divergent step for batch_size=1

                        s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                        s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                        s.regret_gradient = loss_direction(g, s.step_openai)

                    with u.timeit(f'rho-{i}'):
                        s.rho, s.lyap_erank, lyap_evals = u.truncated_lyapunov_rho(H, sigma)
                        s.step_div_1_adjusted = s.step_div_1/s.rho

                    with u.timeit(f"batch-{i}"):
                        s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                        s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n  # Gradient diversity / n
                        s.noise_variance_pinv = torch.trace(pinvH @ sigma)
                        s.H_erank = torch.trace(H) / s.H_l2
                        s.batch_jain_simple = 1 + s.H_erank
                        s.batch_jain_full = 1 + s.rho * s.H_erank

                    param_name = f"{layer.name}={param_names[param]}"
                    u.log_scalars(u.nest_stats(f"{param_name}", s))

                    H_evals = u.symeig_pos_evals(H)
                    sigma_evals = u.symeig_pos_evals(sigma)
                    u.log_spectrum(f'{param_name}/hess', H_evals)
                    u.log_spectrum(f'{param_name}/sigma', sigma_evals)
                    u.log_spectrum(f'{param_name}/lyap', lyap_evals)

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1-args.weight_decay)

                gl.increment_global_step(data.shape[0])

    gl.event_writer.close()
Example #16
def test_main():

    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--test-batch-size',
                        type=int,
                        default=1000,
                        metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10,
                        metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.01,
                        metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum',
                        type=float,
                        default=0.5,
                        metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda',
                        action='store_true',
                        default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed',
                        type=int,
                        default=1,
                        metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=10,
        metavar='N',
        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model',
                        action='store_true',
                        default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb',
                        type=int,
                        default=1,
                        help='log to weights and biases')
    parser.add_argument('--autograd_check',
                        type=int,
                        default=0,
                        help='autograd correctness checks')
    parser.add_argument('--logdir',
                        type=str,
                        default='/temp/runs/curv_train_tiny/run')

    parser.add_argument('--train_batch_size', type=int, default=100)
    parser.add_argument('--stats_batch_size', type=int, default=60000)
    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps',
                        type=int,
                        default=100,
                        help="this many train steps between stat collection")
    parser.add_argument('--stats_steps',
                        type=int,
                        default=1000000,
                        help="total number of curvature stats collections")
    parser.add_argument('--nonlin',
                        type=int,
                        default=1,
                        help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--method',
                        type=str,
                        choices=['gradient', 'newton'],
                        default='gradient',
                        help="descent method, newton or gradient")
    parser.add_argument('--layer',
                        type=int,
                        default=-1,
                        help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--lmb', type=float, default=1e-3)
    parser.add_argument(
        '--hess_samples',
        type=int,
        default=1,
        help='number of samples when sub-sampling outputs, 0 for exact hessian'
    )
    parser.add_argument('--hess_kfac',
                        type=int,
                        default=0,
                        help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho',
                        type=int,
                        default=1,
                        help='use expensive method to compute rho')
    parser.add_argument('--skip_stats',
                        type=int,
                        default=0,
                        help='skip all stats collection')
    parser.add_argument('--full_batch',
                        type=int,
                        default=0,
                        help='do stats on the whole dataset')
    parser.add_argument('--weight_decay', type=float, default=1e-4)

    #args = parser.parse_args()
    args = AttrDict()
    args.lmb = 1e-3
    args.compute_rho = 1
    args.weight_decay = 1e-4
    args.method = 'gradient'
    args.logdir = '/tmp'
    args.data_width = 2
    args.targets_width = 2
    args.train_batch_size = 10
    args.full_batch = False
    args.skip_stats = False
    args.autograd_check = False

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    #gl.event_writer = SummaryWriter(logdir)
    gl.event_writer = u.NoOp()
    # print(f"Logging to {run_name}")

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width**2
    d2 = 2
    d3 = args.targets_width**2

    d1 = args.data_width**2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    if loss_type == 'CrossEntropy':
        d3 = 10
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size) + 1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    # optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=dsize,
                          original_targets=True)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.stats_batch_size,
            shuffle=False,
            drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               train=False,
                               dataset_size=dsize,
                               original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.train_batch_size,
                                              shuffle=False,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            # print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()  # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i] > 50 or d[i + 1] > 50:
                print(
                    f'layer {i} is too big ({d[i], d[i + 1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape),
                          layer.weight.grad1)
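            # Sketch of what G holds (assuming u.khatri_rao_t computes the row-wise
            # Kronecker product of its arguments): for a linear layer the per-example
            # gradient of the weight matrix is the outer product B_t[k]^T A_t[k], and its
            # flattening is kron(B_t[k], A_t[k]). E.g. with d[i]=2, d[i+1]=3 one row of G
            # is [b1*a1, b1*a2, b2*a1, b2*a2, b3*a1, b3*a2], which is why G reshapes to
            # layer.weight.grad1 above and averages to the ordinary gradient g.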

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel(
            )  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2
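            # efisher = G^T G / n is the uncentered second moment of per-example
            # gradients (the "empirical Fisher"), sigma is their covariance, and
            # sigma_erank = tr(sigma) / ||sigma||_2 is its effective rank.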

            lambda_regularizer = args.lmb * torch.eye(d[i + 1] * d[i]).to(
                gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(torch.cholesky(H + lambda_regularizer))

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = H.pinverse()

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(
                    g @ pinvH @ g.t() / 2)  # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)
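                # On the local quadratic model: step_max = 2 / lambda_max(H) is the
                # largest stable gradient-descent step and step_min = 2 / tr(H) a
                # conservative one; step_openai is a gradient-magnitude-to-curvature
                # ratio (||g||^2 vs g @ H @ g^T); regret_newton = g @ H^+ @ g^T / 2 is the
                # predicted loss reduction of a full Newton step, and regret_gradient
                # evaluates the same quadratic model for a step of size step_openai
                # along g.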

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_spectral(H, sigma)

                discrepancy = torch.max(abs(p_sigma - p_sigma.t()) / p_sigma)

                s.psigma_erank = u.sym_erank(p_sigma)
                s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank
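                # Batch-size heuristics: batch_openai = tr(H @ sigma) / (g @ H @ g^T) is
                # the gradient-noise-to-curvature ratio behind "critical batch size"
                # estimates, while batch_jain_simple/full grow with 1 + the effective
                # rank of H, optionally weighted by the statistical/curvature mismatch
                # rho computed above.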

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                # u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1 - args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(
                            u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                            param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(
                            param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()

    assert val_losses[0] > 2.4  # 2.4828238487243652
    assert val_losses[-1] < 2.25  # 2.20609712600708
    def compute_layer_stats(layer):
        stats = AttrDefault(str, {})
        n = stats_batch_size
        param = u.get_param(layer)
        d = len(param.flatten())
        layer_idx = model.layers.index(layer)
        assert layer_idx >= 0
        assert stats_data.shape[0] == n

        def backprop_loss():
            model.zero_grad()
            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            loss.backward()
            return loss, output

        def backprop_output():
            model.zero_grad()
            output = model(stats_data)
            output.backward(gradient=torch.ones_like(output))
            return output

        # per-example gradients, n, d
        loss, output = backprop_loss()
        At = layer.data_input
        Bt = layer.grad_output * n
        G = u.khatri_rao_t(At, Bt)
        g = G.sum(dim=0, keepdim=True) / n
        u.check_close(g, u.vec(param.grad).t())

        stats.diversity = torch.norm(G, "fro")**2 / g.flatten().norm()**2

        stats.gradient_norm = g.flatten().norm()
        stats.parameter_norm = param.data.flatten().norm()
        pos_activations = torch.sum(layer.data_output > 0)
        neg_activations = torch.sum(layer.data_output <= 0)
        stats.sparsity = pos_activations.float() / (pos_activations +
                                                    neg_activations)

        output = backprop_output()
        At2 = layer.data_input
        u.check_close(At, At2)
        B2t = layer.grad_output
        J = u.khatri_rao_t(At, B2t)
        H = J.t() @ J / n
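        # Assuming compute_loss is a least-squares loss (consistent with the Newton
        # decrement check below), the layer Hessian coincides with the Gauss-Newton
        # matrix J^T J / n built from the per-example output Jacobians stored in J,
        # which is what the autograd comparison below verifies.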

        model.zero_grad()
        output = model(stats_data)  # use last saved data batch for backprop
        loss = compute_loss(output, stats_targets)
        hess = u.hessian(loss, param)

        hess = hess.transpose(2, 3).transpose(0, 1).reshape(d, d)
        u.check_close(hess, H)

        stats.hessian_norm = u.l2_norm(H)
        stats.jacobian_norm = u.l2_norm(J)
        Joutput = J.sum(dim=0) / n
        stats.jacobian_sensitivity = Joutput.norm()

        # newton decrement
        stats.loss_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
        u.check_close(stats.loss_newton, loss)
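        # loss_newton = g @ H^+ @ g^T / 2 is the Newton decrement; for a quadratic loss
        # whose minimum value is zero it equals the current loss, hence the check above.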

        # do line-search to find optimal step
        def line_search(directionv, start, end, steps=10):
            """Takes steps between start and end, returns steps+1 loss entries"""
            param0 = param.data.clone()
            param0v = u.vec(param0).t()
            losses = []
            for i in range(steps + 1):
                output = model(
                    stats_data)  # use last saved data batch for backprop
                loss = compute_loss(output, stats_targets)
                losses.append(loss)
                offset = start + i * ((end - start) / steps)
                param1v = param0v + offset * directionv

                param1 = u.unvec(param1v.t(), param.data.shape[0])
                param.data.copy_(param1)

            output = model(
                stats_data)  # use last saved data batch for backprop
            loss = compute_loss(output, stats_targets)
            losses.append(loss)

            param.data.copy_(param0)
            return losses

        # try to take a newton step
        gradv = g
        line_losses = line_search(-gradv @ u.pinv(H), 0, 2, steps=10)
        u.check_equal(line_losses[0], loss)
        u.check_equal(line_losses[6], 0)
        assert line_losses[5] > line_losses[6]
        assert line_losses[7] > line_losses[6]
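        # With start=0, end=2, steps=10 the loss is recorded before each parameter
        # update, so line_losses[k] for k >= 1 is the loss at offset (k - 1) * 0.2 along
        # the direction; index 6 is offset 1.0, i.e. the full Newton step, which should
        # reach (near) zero loss for this quadratic and be a local minimum of the line
        # search, hence the three checks above.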
        return stats
Example #18
0
def main():
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['d1'] = d1
            wandb.config['d2'] = d2
            wandb.config['d3'] = d3
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)

    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width,
                               targets_width=args.targets_width,
                               dataset_size=args.dataset_size,
                               train=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.stats_batch_size,
                                              shuffle=True,
                                              drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    skip_forward_hooks = False
    skip_backward_hooks = False

    def capture_activations(module: nn.Module, input: List[torch.Tensor],
                            output: torch.Tensor):
        if skip_forward_hooks:
            return
        assert not hasattr(
            module, 'activations'
        ), "Seeing results of previous autograd, call util.zero_grad to clear"
        assert len(input) == 1, "this was tested for single input layers only"
        setattr(module, "activations", input[0].detach())
        setattr(module, "output", output.detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    last_outer = 0
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars(
                {"time/outer": 1000 * (time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()
        # compute validation loss
        skip_forward_hooks = True
        skip_backward_hooks = True
        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        with u.timeit("backprop_g"):
            gl.backward_idx = 0
            u.zero_grad(model)
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o).to(gl.device)

        u.log_scalar(loss=loss.item())

        with u.timeit("backprop_H"):
            # optionally use randomized low-rank approximation of Hessian
            hess_rank = args.hess_samples if args.hess_samples else o

            for out_idx in range(hess_rank):
                model.zero_grad()
                # backprop to get section of batch output jacobian for output at position out_idx
                output = model(
                    data
                )  # opt: using autograd.grad means I don't have to zero_grad
                if args.hess_samples:
                    bval = torch.LongTensor(n, o).to(gl.device).random_(
                        0, 2) * 2 - 1
                    bval = bval.float()
                else:
                    ei = id_mat[out_idx]
                    bval = torch.stack([ei] * n)
                gl.backward_idx = out_idx + 1
                output.backward(bval)
            skip_backward_hooks = True
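        # Each backward pass above stores one slice of the batch output Jacobian in
        # module.backprops: with bval = e_i (a batch of identical basis vectors) the
        # hess_rank = o passes reconstruct the exact Hessian of this least-squares loss
        # (via H = Jb^T Jb / n below), while random +/-1 vectors (--hess_samples) give a
        # cheaper randomized sketch of it.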

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
                s.sigma_erank = torch.trace(sigma) / s.sigma_l2

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [
                layer.backprops[out_idx + 1] for out_idx in range(hess_rank)
            ]
            Amat_t = torch.cat([A_t] * hess_rank, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * hess_rank, d[i])
            assert Bmat_t.shape == (n * hess_rank, d[i + 1])

            lambda_regularizer = args.lmb * torch.eye(d[i] * d[i + 1]).to(
                gl.device)
            with u.timeit(f"khatri_H-{i}"):
                Jb = u.khatri_rao_t(
                    Bmat_t, Amat_t)  # batch Jacobian, in row-vec format

            with u.timeit(f"H-{i}"):
                H = Jb.t() @ Jb / n

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(torch.cholesky(H + lambda_regularizer))

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.jacobian_fro = Jb.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          (dd.flatten().norm()**2))

            with u.timeit("pinvH"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre',
                        pinvH)  # save Newton preconditioner
                s.step_openai = 1 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / u.sym_l2_norm(H)
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm(
                )  # frobenius norm of Newton update
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(
                        p_sigma) and args.compute_rho:  # use expensive method
                    H0 = H.cpu().detach().numpy()
                    sigma0 = sigma.cpu().detach().numpy()
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                print('openai batch', s.batch_openai)
                s.diversity = torch.norm(G, "fro")**2 / torch.norm(g)**2

                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        last_inner = 0
        for i in range(args.train_steps):
            if last_inner:
                u.log_scalars(
                    {"time/inner": 1000 * (time.perf_counter() - last_inner)})
            last_inner = time.perf_counter()

            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
Example #19
0
    def fit(self, epochs, start_epoch=0):

        try:
            for epoch in range(start_epoch, epochs):
                # Training
                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.training:
                        if dataset_config.frequency and (
                            (epoch + 1) % dataset_config.frequency):
                            continue
                        self.train(epoch, name, dataset_config)
                    # notify the model that training for this epoch is done
                    epoch_done_op = getattr(self.bare_model, "epoch_done",
                                            None)
                    if callable(epoch_done_op):
                        epoch_done_op(epoch)
                if self.use_swa and (epoch + 1) >= self.use_swa and (
                        epoch + 1 - self.use_swa) % self.swa_c_epochs == 0:
                    swa_moving_average(self.swa_model, self.bare_model,
                                       1.0 / (self.swa_n + 1))
                    self.swa_n += 1
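                    # Assuming swa_moving_average(target, source, alpha) moves target
                    # toward source by a factor alpha, updating with alpha = 1/(swa_n + 1)
                    # keeps swa_model equal to the running arithmetic mean of the
                    # checkpoints collected every swa_c_epochs epochs.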
                    if not self.config["swa_no_bn_update"]:
                        bn_update(self.data_loaders['training'],
                                  self.swa_model)
                    self.state['swa_state_dict'] = self.swa_model.state_dict()
                    self.state['swa_n'] = self.swa_n
                    #self.run.info['swa_n'] = self.swa_n
                    self.save_model(epoch)
                    # Testing
                    swa_testing_result = {}
                    for name in self.config.datasets:
                        dataset_config = AttrDefault(
                            lambda: None, self.config.datasets[name])
                        if dataset_config.testing:
                            swa_testing_result[name] = self.test(
                                epoch,
                                name,
                                dataset_config,
                                model=self.swa_model,
                                extra_name="_swa")

                # Testing
                testing_result = {}
                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.testing:
                        testing_result[name] = self.test(
                            epoch, name, dataset_config)

                # updating the state with new results
                self.update_state(testing_result, epoch)

                #self.run.info['epoch'] = epoch
                self.eventAfterEpoch(self, epoch)

                if shared_globals.current_learning_rate < self.min_lr:
                    shared_globals.console.info(
                        "learning rate reached minimum {} ({}), stopping in epoch {}"
                        .format(self.min_lr,
                                shared_globals.current_learning_rate, epoch))
                    break

        except KeyboardInterrupt:
            pass
        shared_globals.console.info("last test:\n" +
                                    str(self.state['metrics']))
Example #20
0
def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='For Saving the current Model')

    parser.add_argument('--wandb', type=int, default=0, help='log to weights and biases')
    parser.add_argument('--autograd_check', type=int, default=0, help='autograd correctness checks')
    parser.add_argument('--logdir', type=str, default='/tmp/runs/curv_train_tiny/run')

    parser.add_argument('--nonlin', type=int, default=1, help="whether to add ReLU nonlinearity between layers")
    parser.add_argument('--bias', type=int, default=1, help="whether to add bias between layers")

    parser.add_argument('--layer', type=int, default=-1, help="restrict updates to this layer")
    parser.add_argument('--data_width', type=int, default=28)
    parser.add_argument('--targets_width', type=int, default=28)
    parser.add_argument('--hess_samples', type=int, default=1,
                        help='number of samples when sub-sampling outputs, 0 for exact hessian')
    parser.add_argument('--hess_kfac', type=int, default=0, help='whether to use KFAC approximation for hessian')
    parser.add_argument('--compute_rho', type=int, default=0, help='use expensive method to compute rho')
    parser.add_argument('--skip_stats', type=int, default=0, help='skip all stats collection')

    parser.add_argument('--dataset_size', type=int, default=60000)
    parser.add_argument('--train_steps', type=int, default=100, help="this many train steps between stat collection")
    parser.add_argument('--stats_steps', type=int, default=1000000, help="total number of curvature stats collections")

    parser.add_argument('--full_batch', type=int, default=0, help='do stats on the whole dataset')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--dropout', type=int, default=0)
    parser.add_argument('--swa', type=int, default=0)
    parser.add_argument('--lmb', type=float, default=1e-3)

    parser.add_argument('--train_batch_size', type=int, default=64)
    parser.add_argument('--stats_batch_size', type=int, default=10000)
    parser.add_argument('--stats_num_batches', type=int, default=1)
    parser.add_argument('--run_name', type=str, default='noname')
    parser.add_argument('--launch_blocking', type=int, default=0)
    parser.add_argument('--sampled', type=int, default=0)
    parser.add_argument('--curv', type=str, default='kfac',
                        help='decomposition to use for curvature estimates: zero_order, kfac, isserlis or full')
    parser.add_argument('--log_spectra', type=int, default=0)

    u.seed_random(1)
    gl.args = parser.parse_args()
    args = gl.args
    u.seed_random(1)

    gl.project_name = 'train_ciresan'
    u.setup_logdir_and_event_writer(args.run_name)
    print(f"Logging to {gl.logdir}")

    d1 = 28 * 28
    d = [784, 2500, 2000, 1500, 1000, 500, 10]

    # number of samples per datapoint. Used to normalize kfac
    model = u.SimpleFullyConnected2(d, nonlin=args.nonlin, bias=args.bias, dropout=args.dropout)
    model = model.to(gl.device)
    autograd_lib.register(model)

    assert args.dataset_size >= args.stats_batch_size
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, original_targets=True,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    assert not args.full_batch, "fixme: validation still uses stats_iter"
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=True,
                                                   drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)
    else:
        stats_iter = None

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False,
                               original_targets=True,
                               dataset_size=args.dataset_size)
    test_eval_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                   drop_last=False)
    train_eval_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False,
                                                    drop_last=False)

    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.add_hooks(model)
    autograd_lib.disable_hooks()

    gl.token_count = 0
    last_outer = 0

    for step in range(args.stats_steps):
        epoch = gl.token_count // 60000
        lr = optimizer.param_groups[0]['lr']
        print('token_count', gl.token_count)
        if last_outer:
            u.log_scalars({"time/outer": 1000 * (time.perf_counter() - last_outer)})
            print(f'time: {time.perf_counter() - last_outer:.2f}')
        last_outer = time.perf_counter()

        with u.timeit("validate"):
            val_accuracy, val_loss = validate(model, test_eval_loader, f'test (epoch {epoch})')
            train_accuracy, train_loss = validate(model, train_eval_loader, f'train (epoch {epoch})')

        # save log
        metrics = {'epoch': epoch, 'val_accuracy': val_accuracy, 'val_loss': val_loss,
                   'train_loss': train_loss, 'train_accuracy': train_accuracy,
                   'lr': optimizer.param_groups[0]['lr'],
                   'momentum': optimizer.param_groups[0].get('momentum', 0)}
        u.log_scalars(metrics)

        def mom_update(buffer, val):
            buffer *= 0.9
            buffer += val * 0.1

        if not args.skip_stats:
            # number of samples passed through
            n = args.stats_batch_size * args.stats_num_batches

            # quantities accumulated over the stats batches
            forward_stats = defaultdict(lambda: AttrDefault(float))

            hessians = defaultdict(lambda: AttrDefault(float))
            jacobians = defaultdict(lambda: AttrDefault(float))
            fishers = defaultdict(lambda: AttrDefault(float))  # empirical fisher/gradient
            quad_fishers = defaultdict(lambda: AttrDefault(float))  # gradient statistics that depend on fisher (4th order moments)
            train_regrets = defaultdict(list)
            test_regrets1 = defaultdict(list)
            test_regrets2 = defaultdict(list)
            train_regrets_opt = defaultdict(list)
            test_regrets_opt = defaultdict(list)
            cosines = defaultdict(list)
            dot_products = defaultdict(list)
            hessians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            jacobians_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))
            quad_fishers_histograms = defaultdict(lambda: AttrDefault(u.MyList))

            current = None
            current_histograms = None

            for i in range(args.stats_num_batches):
                activations = {}
                backprops = {}

                def save_activations(layer, A, _):
                    activations[layer] = A
                    forward_stats[layer].AA += torch.einsum("ni,nj->ij", A, A)

                print('forward')
                with u.timeit("stats_forward"):
                    with autograd_lib.module_hook(save_activations):
                        data, targets = next(stats_iter)
                        output = model(data)
                        loss = loss_fn(output, targets) * len(output)

                def compute_stats(layer, _, B):
                    A = activations[layer]
                    if current == fishers:
                        backprops[layer] = B

                    # about 27ms per layer
                    with u.timeit('compute_stats'):
                        current[layer].BB += torch.einsum("ni,nj->ij", B, B)  # TODO(y): index consistency
                        current[layer].diag += torch.einsum("ni,nj->ij", B * B, A * A)
                        current[layer].BA += torch.einsum("ni,nj->ij", B, A)
                        current[layer].a += torch.einsum("ni->i", A)
                        current[layer].b += torch.einsum("nk->k", B)
                        current[layer].norm2 += ((A * A).sum(dim=1) * (B * B).sum(dim=1)).sum()
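                        # For a linear layer the per-example gradient is the outer
                        # product b_k a_k^T, so its squared Frobenius norm factorizes as
                        # ||a_k||^2 * ||b_k||^2; that is why norm2 accumulates
                        # (A*A).sum(1) * (B*B).sum(1) without materializing per-example
                        # gradients.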

                        # compute curvatures in the direction of all gradients
                        if current is fishers:
                            assert args.stats_num_batches == 1, "not tested on more than one stats step, currently reusing aggregated moments"
                            hess = hessians[layer]
                            jac = jacobians[layer]
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n
                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))

                            current[layer].min_norm2 = min(norms)
                            current[layer].median_norm2 = torch.median(norms)
                            current[layer].max_norm2 = max(norms)

                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))
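                            # Kronecker trick (assuming a KFAC-style factorization
                            # M ~ BB (x) AA of the curvature): for g_k = kron(b_k, a_k),
                            # g_k^T M g_k = (b_k^T BB b_k) * (a_k^T AA a_k)
                            #             = (Bh * B).sum(1)[k] * (Ah * A).sum(1)[k],
                            # so norms2_hess / norms is the per-example curvature along
                            # each example's own gradient direction.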

                            current[layer].norm += norms.sum()
                            current_histograms[layer].norms.extend(torch.sqrt(norms))
                            current[layer].curv_hess += (skip_nans(norms2_hess / norms)).sum()
                            current_histograms[layer].curv_hess.extend(skip_nans(norms2_hess / norms))
                            current[layer].curv_hess_max += (skip_nans(norms2_hess / norms)).max()
                            current[layer].curv_hess_median += (skip_nans(norms2_hess / norms)).median()

                            current_histograms[layer].curv_jac.extend(skip_nans(norms2_jac / norms))
                            current[layer].curv_jac += (skip_nans(norms2_jac / norms)).sum()
                            current[layer].curv_jac_max += (skip_nans(norms2_jac / norms)).max()
                            current[layer].curv_jac_median += (skip_nans(norms2_jac / norms)).median()

                            current[layer].a_sparsity += torch.sum(A <= 0).float() / A.numel()
                            current[layer].b_sparsity += torch.sum(B <= 0).float() / B.numel()

                            current[layer].mean_activation += torch.mean(A)
                            current[layer].mean_activation2 += torch.mean(A*A)
                            current[layer].mean_backprop += torch.mean(B)
                            current[layer].mean_backprop2 += torch.mean(B*B)

                            current[layer].norms_hess += torch.sqrt(norms2_hess).sum()
                            current_histograms[layer].norms_hess.extend(torch.sqrt(norms2_hess))
                            current[layer].norms_jac += norms2_jac.sum()
                            current_histograms[layer].norms_jac.extend(torch.sqrt(norms2_jac))

                            normalized_moments = copy.copy(hessians[layer])
                            normalized_moments.AA = forward_stats[layer].AA
                            normalized_moments = u.divide_attributes(normalized_moments, n)

                            train_regrets_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=0, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets1_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=1, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets2_ = autograd_lib.offset_losses(A, B, alpha=lr, offset=2, m=normalized_moments,
                                                                        approx=args.curv)
                            test_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=2,
                                                                           m=normalized_moments, approx=args.curv)
                            train_regrets_opt_ = autograd_lib.offset_losses(A, B, alpha=None, offset=0,
                                                                            m=normalized_moments, approx=args.curv)
                            cosines_ = autograd_lib.offset_cosines(A, B)
                            train_regrets[layer].extend(train_regrets_)
                            test_regrets1[layer].extend(test_regrets1_)
                            test_regrets2[layer].extend(test_regrets2_)
                            train_regrets_opt[layer].extend(train_regrets_opt_)
                            test_regrets_opt[layer].extend(test_regrets_opt_)
                            cosines[layer].extend(cosines_)
                            dot_products[layer].extend(autograd_lib.offset_dotprod(A, B))
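                            # The offset regrets appear to measure the loss change from a
                            # step of size lr computed on one example and evaluated on an
                            # example `offset` positions away: offset=0 behaves like a
                            # train regret, offset>0 like a held-out regret, and
                            # alpha=None uses the per-example optimal step (the
                            # train/test regret ratio is logged later as an overfitting
                            # signal).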

                        # statistics of the form g.Sigma.g
                        elif current == quad_fishers:
                            hess = hessians[layer]
                            sigma = fishers[layer]
                            jac = jacobians[layer]
                            Bs, As = B @ sigma.BB / n, A @ forward_stats[layer].AA / n
                            Bh, Ah = B @ hess.BB / n, A @ forward_stats[layer].AA / n
                            Bj, Aj = B @ jac.BB / n, A @ forward_stats[layer].AA / n

                            norms = ((A * A).sum(dim=1) * (B * B).sum(dim=1))
                            norms2_hess = ((Ah * A).sum(dim=1) * (Bh * B).sum(dim=1))
                            norms2_jac = ((Aj * A).sum(dim=1) * (Bj * B).sum(dim=1))
                            norms_sigma = ((As * A).sum(dim=1) * (Bs * B).sum(dim=1))

                            current[layer].norm += norms.sum()  # TODO(y) remove, redundant with norm2 above
                            current[layer].curv_sigma += (skip_nans(norms_sigma / norms)).sum()
                            current[layer].curv_sigma_max = skip_nans(norms_sigma / norms).max()
                            current[layer].curv_sigma_median = skip_nans(norms_sigma / norms).median()
                            current[layer].curv_hess += skip_nans(norms2_hess / norms).sum()
                            current[layer].curv_hess_max += skip_nans(norms2_hess / norms).max()
                            current[layer].lyap_hess_mean += skip_nans(norms_sigma / norms2_hess).mean()
                            current[layer].lyap_hess_max = max(skip_nans(norms_sigma/norms2_hess))
                            current[layer].lyap_jac_mean += skip_nans(norms_sigma / norms2_jac).mean()
                            current[layer].lyap_jac_max = max(skip_nans(norms_sigma/norms2_jac))

                print('backward')
                with u.timeit("backprop_H"):
                    with autograd_lib.module_hook(compute_stats):
                        current = hessians
                        current_histograms = hessians_histograms
                        autograd_lib.backward_hessian(output, loss='CrossEntropy', sampled=args.sampled,
                                                      retain_graph=True)  # 600 ms
                        current = jacobians
                        current_histograms = jacobians_histograms
                        autograd_lib.backward_jacobian(output, sampled=args.sampled, retain_graph=True)  # 600 ms
                        current = fishers
                        current_histograms = fishers_histograms
                        model.zero_grad()
                        loss.backward(retain_graph=True)  # 60 ms
                        current = quad_fishers
                        current_histograms = quad_fishers_histograms
                        model.zero_grad()
                        loss.backward()  # 60 ms

            print('summarize')
            for (i, layer) in enumerate(model.layers):
                stats_dict = {'hessian': hessians, 'jacobian': jacobians, 'fisher': fishers}

                # evaluate stats from
                # https://app.wandb.ai/yaroslavvb/train_ciresan/runs/425pu650?workspace=user-yaroslavvb
                for stats_name in stats_dict:
                    s = AttrDict()
                    stats = stats_dict[stats_name][layer]

                    for key in forward_stats[layer]:
                        # print(f'copying {key} in {stats_name}, {layer}')
                        assert stats[key] == float(), \
                            f"Trying to overwrite {key} in {stats_name}, {layer}"
                        stats[key] = forward_stats[layer][key]

                    diag: torch.Tensor = stats.diag / n

                    # jacobian:
                    # curv in direction of gradient goes down to roughly 0.3-1
                    # maximum curvature goes up to 1000-2000
                    #
                    # Hessian:
                    # max curv goes down to 1, in direction of gradient 0.0001

                    s.diag_l2 = torch.max(diag)  # 40 - 3000 smaller than kfac l2 for jac
                    s.diag_fro = torch.norm(
                        diag)  # jacobian grows to 0.5-1.5, rest falls, layer-5 has phase transition, layer-4 also
                    s.diag_trace = diag.sum()  # jacobian grows 0-1000 (first), 0-150 (last). Almost same as kfac_trace (771 vs 810 kfac). Jacobian has up/down phase transition
                    s.diag_average = diag.mean()

                    # normalize for mean loss
                    BB = stats.BB / n
                    AA = stats.AA / n
                    # A_evals, _ = torch.symeig(AA)   # averaging 120ms per hit, 90 hits
                    # B_evals, _ = torch.symeig(BB)

                    # s.kfac_l2 = torch.max(A_evals) * torch.max(B_evals)    # 60x larger than diag_l2. layer0/hess has down/up phase transition. layer5/jacobian has up/down phase transition
                    s.kfac_trace = torch.trace(AA) * torch.trace(BB)  # 0/hess down/up tr, 5/jac sharp phase transition
                    s.kfac_fro = torch.norm(stats.AA) * torch.norm(
                        stats.BB)  # 0/hess has down/up tr, 5/jac up/down transition
                    # s.kfac_erank = s.kfac_trace / s.kfac_l2   # first layer has 25, rest 15, all layers go down except last, last noisy
                    # s.kfac_erank_fro = s.kfac_trace / s.kfac_fro / max(stats.BA.shape)

                    s.diversity = (stats.norm2 / n) / u.norm_squared(
                        stats.BA / n)  # gradient diversity. Goes up 3x. Bottom layer has most diversity. Jacobian diversity is much less noisy than everything else.

                    # discrepancy of KFAC based on exact values of diagonal approximation
                    # average difference normalized by average diagonal magnitude
                    diag_kfac = torch.einsum('ll,ii->li', BB, AA)
                    s.kfac_error = (torch.abs(diag_kfac - diag)).mean() / torch.mean(diag.abs())
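                    # The exact diagonal of the Kronecker product BB (x) AA is the outer
                    # product of diag(BB) and diag(AA) (the einsum above), so kfac_error
                    # is the mean absolute deviation of the KFAC diagonal from the true
                    # per-parameter diagonal, normalized by the mean diagonal magnitude.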
                    u.log_scalars(u.nest_stats(f'layer-{i}/{stats_name}', s))

                # openai batch size stat
                s = AttrDict()
                hess = hessians[layer]
                jac = jacobians[layer]
                fish = fishers[layer]
                quad_fish = quad_fishers[layer]

                # the following check passes, but is expensive
                # if args.stats_num_batches == 1:
                #    u.check_close(fisher[layer].BA, layer.weight.grad)

                def trsum(A, B):
                    return (A * B).sum()  # computes tr(AB')

                grad = fishers[layer].BA / n
                s.grad_fro = torch.norm(grad)

                # get norms
                s.lyap_hess_max = quad_fish.lyap_hess_max
                s.lyap_hess_ave = quad_fish.lyap_hess_mean / args.stats_num_batches  # accumulator stores per-batch means
                s.lyap_jac_max = quad_fish.lyap_jac_max
                s.lyap_jac_ave = quad_fish.lyap_jac_mean / args.stats_num_batches
                s.hess_trace = hess.diag.sum() / n
                s.jac_trace = jac.diag.sum() / n

                # Version 1 of Jain stochastic rates, use Hessian for curvature
                b = args.train_batch_size

                s.hess_curv = trsum((hess.BB / n) @ grad @ (hess.AA / n), grad) / trsum(grad, grad)
                s.jac_curv = trsum((jac.BB / n) @ grad @ (jac.AA / n), grad) / trsum(grad, grad)
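                # trsum(A, B) = tr(A @ B^T), so with grad as a d_out x d_in matrix
                # trsum((BB/n) @ grad @ (AA/n), grad) equals
                # vec(grad)^T (AA (x) BB) vec(grad) under the KFAC factorization;
                # hess_curv / jac_curv are therefore Rayleigh quotients measuring
                # curvature along the current gradient direction.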

                # compute gradient noise statistics
                # fish.BB has /n factor twice, hence don't need extra /n on fish.AA
                # after sampling, hess_noise,jac_noise became 100x smaller, but normalized is unaffected
                s.hess_noise = (trsum(hess.AA / n, fish.AA / n) * trsum(hess.BB / n, fish.BB / n))
                s.jac_noise = (trsum(jac.AA / n, fish.AA / n) * trsum(jac.BB / n, fish.BB / n))
                s.hess_noise_centered = s.hess_noise - trsum(hess.BB / n @ grad, grad @ hess.AA / n)
                s.jac_noise_centered = s.jac_noise - trsum(jac.BB / n @ grad, grad @ jac.AA / n)
                s.openai_gradient_noise = (fish.norms_hess / n) / trsum(hess.BB / n @ grad, grad @ hess.AA / n)

                s.mean_norm = torch.sqrt(fish.norm2) / n
                s.min_norm = torch.sqrt(fish.min_norm2)
                s.median_norm = torch.sqrt(fish.median_norm2)
                s.max_norm = torch.sqrt(fish.max_norm2)
                s.enorms = u.norm_squared(grad)
                s.a_sparsity = fish.a_sparsity
                s.b_sparsity = fish.b_sparsity
                s.mean_activation = fish.mean_activation
                s.msr_activation = torch.sqrt(fish.mean_activation2)
                s.mean_backprop = fish.mean_backprop
                s.msr_backprop = torch.sqrt(fish.mean_backprop2)

                s.norms_centered = fish.norm2 / n - u.norm_squared(grad)
                s.norms_hess = fish.norms_hess / n
                s.norms_jac = fish.norms_jac / n

                # phase transition: hits minimum loss in layer 1, then starts going up; other layers take longer to reach their minimum; decreases with depth
                s.hess_curv_grad = fish.curv_hess / n
                s.hess_curv_grad_max = fish.curv_hess_max
                s.hess_curv_grad_median = fish.curv_hess_median
                s.sigma_curv_grad = quad_fish.curv_sigma / n
                s.sigma_curv_grad_max = quad_fish.curv_sigma_max
                s.sigma_curv_grad_median = quad_fish.curv_sigma_median
                s.band_bottou = 0.5 * lr * s.sigma_curv_grad / s.hess_curv_grad
                s.band_bottou_stoch = 0.5 * lr * quad_fish.curv_ratio / n
                s.band_yaida = 0.25 * lr * s.mean_norm**2
                s.band_yaida_centered = 0.25 * lr * s.norms_centered

                # much lower variance than jac_curv; reaches its peak around 10k steps (where the kfac error also peaks); decreases with depth except for the last layer
                s.jac_curv_grad = fish.curv_jac / n
                s.jac_curv_grad_max = fish.curv_jac_max
                s.jac_curv_grad_median = fish.curv_jac_median

                # OpenAI gradient noise statistics
                s.hess_noise_normalized = s.hess_noise_centered / (fish.norms_hess / n)
                s.jac_noise_normalized = s.jac_noise / (fish.norms_jac / n)

                train_regrets_, test_regrets1_, test_regrets2_, train_regrets_opt_, test_regrets_opt_, cosines_, dot_products_ = (torch.stack(r[layer]) for r in (train_regrets, test_regrets1, test_regrets2, train_regrets_opt, test_regrets_opt, cosines, dot_products))
                s.train_regret = train_regrets_.median()  # use median because outliers make it hard to see the trend
                s.test_regret1 = test_regrets1_.median()
                s.test_regret2 = test_regrets2_.median()
                s.test_regret_opt = test_regrets_opt_.median()
                s.train_regret_opt = train_regrets_opt_.median()
                s.mean_dot_product = torch.mean(dot_products_)
                s.median_dot_product = torch.median(dot_products_)

                s.median_cosine = cosines_.median()
                s.mean_cosine = cosines_.mean()

                # get learning rates
                L1 = s.hess_curv_grad / n
                L2 = s.jac_curv_grad / n
                diversity = (fish.norm2 / n) / u.norm_squared(grad)
                robust_diversity = (fish.norm2 / n) / fish.median_norm2
                dotprod_diversity = fish.median_norm2 / s.median_dot_product
                s.lr1 = 2 / (L1 * diversity)
                s.lr2 = 2 / (L2 * diversity)
                s.lr3 = 2 / (L2 * robust_diversity)
                s.lr4 = 2 / (L2 * dotprod_diversity)

                hess_A = u.symeig_pos_evals(hess.AA / n)
                hess_B = u.symeig_pos_evals(hess.BB / n)
                fish_A = u.symeig_pos_evals(fish.AA / n)
                fish_B = u.symeig_pos_evals(fish.BB / n)
                jac_A = u.symeig_pos_evals(jac.AA / n)
                jac_B = u.symeig_pos_evals(jac.BB / n)
                u.log_scalars({f'layer-{i}/hessA_erank': erank(hess_A)})
                u.log_scalars({f'layer-{i}/hessB_erank': erank(hess_B)})
                u.log_scalars({f'layer-{i}/fishA_erank': erank(fish_A)})
                u.log_scalars({f'layer-{i}/fishB_erank': erank(fish_B)})
                u.log_scalars({f'layer-{i}/jacA_erank': erank(jac_A)})
                u.log_scalars({f'layer-{i}/jacB_erank': erank(jac_B)})
                gl.event_writer.add_histogram(f'layer-{i}/hist_hess_eig', u.outer(hess_A, hess_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_fish_eig', u.outer(fish_A, fish_B).flatten(), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_jac_eig', u.outer(jac_A, jac_B).flatten(), gl.get_global_step())

                s.hess_l2 = max(hess_A) * max(hess_B)
                s.jac_l2 = max(jac_A) * max(jac_B)
                s.fish_l2 = max(fish_A) * max(fish_B)
                s.hess_trace = hess.diag.sum() / n

                s.jain1_sto = 1/(s.hess_trace + 2 * s.hess_l2)
                s.jain1_det = 1/s.hess_l2

                s.jain1_lr = (1 / b) * (1/s.jain1_sto) + (b - 1) / b * (1/s.jain1_det)
                s.jain1_lr = 2 / s.jain1_lr

                s.regret_ratio = (
                            train_regrets_opt_ / test_regrets_opt_).median()  # ratio between train and test regret, large means overfitting
                u.log_scalars(u.nest_stats(f'layer-{i}', s))

                # compute stats that would let you bound rho
                if i == 0:  # only compute this once, for output layer
                    hhh = hessians[model.layers[-1]].BB / n
                    fff = fishers[model.layers[-1]].BB / n
                    d = fff.shape[0]
                    L = u.lyapunov_spectral(hhh, 2 * fff, cond=1e-8)
                    L_evals = u.symeig_pos_evals(L)
                    Lcheap = fff @ u.pinv(hhh, cond=1e-8)
                    Lcheap_evals = u.eig_real(Lcheap)

                    u.log_scalars({f'mismatch/rho': d/erank(L_evals)})
                    u.log_scalars({f'mismatch/rho_cheap': d/erank(Lcheap_evals)})
                    u.log_scalars({f'mismatch/diagonalizability': erank(L_evals)/erank(Lcheap_evals)})  # 1 means diagonalizable
                    u.log_spectrum(f'mismatch/sigma', u.symeig_pos_evals(fff), loglog=False)
                    u.log_spectrum(f'mismatch/hess', u.symeig_pos_evals(hhh), loglog=False)
                    u.log_spectrum(f'mismatch/lyapunov', L_evals, loglog=True)
                    u.log_spectrum(f'mismatch/lyapunov_cheap', Lcheap_evals, loglog=True)

                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms', u.to_numpy(fishers_histograms[layer].norms.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_grad_norms_hess', u.to_numpy(fishers_histograms[layer].norms_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_jac', u.to_numpy(fishers_histograms[layer].curv_jac.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_curv_hess', u.to_numpy(fishers_histograms[layer].curv_hess.value()), gl.get_global_step())
                gl.event_writer.add_histogram(f'layer-{i}/hist_cosines', u.to_numpy(cosines[layer]), gl.get_global_step())

                if args.log_spectra:
                    with u.timeit('spectrum'):
                        # 2/alpha
                        # s.jain1_lr = (1 / b) * s.jain1_sto + (b - 1) / b * s.jain1_det
                        # s.jain1_lr = 1 / s.jain1_lr

                        # hess.diag_trace, jac.diag_trace

                        # Version 2 of Jain stochastic rates, use Jacobian squared for curvature
                        s.jain2_sto = s.lyap_jac_max * s.jac_trace / s.lyap_jac_ave
                        s.jain2_det = s.jac_l2
                        s.jain2_lr = (1 / b) * s.jain2_sto + (b - 1) / b * s.jain2_det
                        s.jain2_lr = 1 / s.jain2_lr

                        u.log_spectrum(f'layer-{i}/hess_A', hess_A)
                        u.log_spectrum(f'layer-{i}/hess_B', hess_B)
                        u.log_spectrum(f'layer-{i}/hess_AB', u.outer(hess_A, hess_B).flatten())
                        u.log_spectrum(f'layer-{i}/jac_A', jac_A)
                        u.log_spectrum(f'layer-{i}/jac_B', jac_B)
                        u.log_spectrum(f'layer-{i}/fish_A', fish_A)
                        u.log_spectrum(f'layer-{i}/fish_B', fish_B)

                        u.log_scalars({f'layer-{i}/trace_ratio': fish_B.sum()/hess_B.sum()})

                        L = torch.eig(u.lyapunov_spectral(hess.BB, 2*fish.BB, cond=1e-8))[0]
                        L = L[:, 0]  # extract real part
                        L = L.sort()[0]
                        L = torch.flip(L, [0])

                        L_cheap = torch.eig(fish.BB @ u.pinv(hess.BB, cond=1e-8))[0]
                        L_cheap = L_cheap[:, 0]  # extract real part
                        L_cheap = L_cheap.sort()[0]
                        L_cheap = torch.flip(L_cheap, [0])

                        d = len(hess_B)
                        u.log_spectrum(f'layer-{i}/Lyap', L)
                        u.log_spectrum(f'layer-{i}/Lyap_cheap', L_cheap)

                        u.log_scalars({f'layer-{i}/dims': d})
                        u.log_scalars({f'layer-{i}/L_erank': erank(L)})
                        u.log_scalars({f'layer-{i}/L_cheap_erank': erank(L_cheap)})

                        u.log_scalars({f'layer-{i}/rho': d/erank(L)})
                        u.log_scalars({f'layer-{i}/rho_cheap': d/erank(L_cheap)})

        model.train()
        with u.timeit('train'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                optimizer.step()
                if args.weight_decay:
                    for group in optimizer.param_groups:
                        for param in group['params']:
                            param.data.mul_(1 - args.weight_decay)

                gl.token_count += data.shape[0]

    gl.event_writer.close()
Example #21
0
    def __init__(self, config, seed=42):
        global logger
        logger = shared_globals.logger
        config = AttrDefault(lambda: None, config)

        self.config = config
        self.datasets = {}
        self.data_loaders = {}
        self.use_swa = config.use_swa
        #self.run.info['epoch'] = 0
        # set random seed
        torch.manual_seed(seed)
        np.random.seed(seed + 1)
        random.seed(seed + 2)

        self.min_lr = self.config.optim_config["min_lr"]
        if self.min_lr is None:
            self.min_lr = 0.0
        print(self.min_lr)
        # making output dirs
        models_outputdir = os.path.join(config.out_dir, "models")
        if not os.path.exists(config.out_dir):
            os.makedirs(config.out_dir)
        if not os.path.exists(models_outputdir):
            os.makedirs(models_outputdir)
        #self.run.info['out_path'] = config.out_dir

        # init_loggers
        self.init_loggers()

        self.dataset_manager = DatasetsManager(self.config['audiodataset'])

        # init Tensor board
        if self.config.tensorboard:
            tensorboard_write_path = config.tensorboard_write_path
            if not tensorboard_write_path:
                tensorboard_write_path = self.config.out_dir.replace(
                    "out", "runs", 1)
            shared_globals.console.info("tensorboard run path: " +
                                        tensorboard_write_path)
            shared_globals.console.info("To monitor this experiment use:\n " +
                                        shared_globals.bcolors.FAIL +
                                        "tensorboard --logdir " +
                                        tensorboard_write_path +
                                        shared_globals.bcolors.ENDC)
            #self.run.info['tensorboard_path'] = tensorboard_write_path
            self.writer = SummaryWriter(tensorboard_write_path)

        # init multi gpu
        self.bare_model = load_model(config.model_config)
        if self.use_swa:
            self.swa_model = load_model(config.model_config)
            if self.config.use_gpu:
                self.swa_model.cuda()
            self.swa_n = 0
            self.swa_c_epochs = config.swa_c_epochs
            self.swa_start = config.swa_start
        if self.config.use_gpu:
            self.bare_model.cuda()

        shared_globals.console.info(
            "Trainable model parameters {}, non-trainable {} ".format(
                count_parameters(self.bare_model),
                count_parameters(self.bare_model, False)))
        # DataParallel mode
        if not config.parallel_mode:
            self.model = self.bare_model
        elif config.parallel_mode == "distributed":
            torch.distributed.init_process_group(
                backend='nccl',
                world_size=1,
                rank=0,
                init_method='file://' + config.out_dir + "/shared_file")
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.bare_model)
        else:
            self.model = torch.nn.DataParallel(self.bare_model)
        # self.model.cuda()

        # if load_model

        if config.get('load_model'):
            load_model_path = config.get('load_model')
            load_model_path = os.path.expanduser(load_model_path)
            shared_globals.console.info("Loading model located at: " +
                                        load_model_path)
            checkpoint = torch.load(load_model_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            if self.use_swa:
                swa_state_dict = checkpoint.get('swa_state_dict', None)
                self.swa_n = checkpoint.get('swa_n', 0)
                if (swa_state_dict
                        is not None) and not self.config.swa_model_load_same:
                    self.swa_model.load_state_dict(swa_state_dict)
                else:
                    shared_globals.console.warning(
                        "No swa_state_dict found in checkpoint; loading the regular state_dict into the SWA model")
                    self.swa_model.load_state_dict(checkpoint['state_dict'])
                    self.swa_n = 0

        shared_globals.logger.info(str(self.model))
        shared_globals.current_learning_rate = config.optim_config['base_lr']
        self.optimizer, self.scheduler = create_optimizer(
            self.model.parameters(), config.optim_config)
        print("optimizer:", self.optimizer)
        loss_criterion_args = dict(config.loss_criterion_args)
        self.criterion = get_criterion(
            config.loss_criterion)(**loss_criterion_args)

        # init state
        inf_value = -float("inf")
        if self.config["optim_config"].get("model_selection",
                                           {}).get("select_min", False):
            inf_value = float("inf")
        self.state = {
            # 'config': self.config,
            'state_dict': None,
            'optimizer': None,
            'epoch': 0,
            'metrics': {},
            'best_metric_value': inf_value,
            'best_epoch': 0,
        }
        self.first_batch_done = False
        # init dataset loaders
        self.init_loaders()

        if config.get('load_model'):
            if not config.get("load_model_no_test_first"):
                testing_result = {}
                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.testing:
                        testing_result[name] = self.test(
                            0, name, dataset_config)

                # updating the state with new results
                self.update_state(testing_result, 0)
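
A minimal sketch of the AttrDefault(lambda: None, ...) pattern the constructor above relies on (assuming the attrdict package's AttrDefault; keys and values here are illustrative only):

from attrdict import AttrDefault

config = AttrDefault(lambda: None, {'use_swa': True, 'out_dir': 'out/exp01'})
use_swa = config.use_swa              # True: present keys read like attributes
parallel_mode = config.parallel_mode  # missing key -> default factory result (None), so `if not config.parallel_mode:` is safe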
Example #22
0
def main():
    ncluster.set_backend('aws')

    if args.config:
        assert not args.instance_type, "specify instance_type as part of config"
        assert not args.machines, "specify number of machines as part of config"
        assert re.match('\\w+', args.config)
        assert args.config in globals(), f'no config called {args.config}'
        config = eval(args.config)

    else:  # setting config vars through command-line flags
        assert args.instance_type
        assert args.machines
        config = {'base_lr': 0.000125 * 5 / 3,
                  'local_batch_size': 96,
                  'instance_type': args.instance_type,
                  'machines': args.machines}

    config = AttrDefault(str, config)  # easier access to dictionary entries
    config.image_name = IMAGE_NAME
    config.conda_env = CONDA_ENV

    if args.conda_env:
        config.conda_env = args.conda_env
        print("Using non-standard conda env ", config.conda_env)
    if args.image_name:
        config.image_name = args.image_name
        print("Using non-standard image ", config.image_name)

    instance_info = ncluster.aws_backend.INSTANCE_INFO[config.instance_type]
    num_gpus_per_machine = instance_info['gpus']

    job = ncluster.make_job(name=args.name,
                            run_name=f"{args.name}",
                            num_tasks=config.machines,
                            image_name=config.image_name,
                            instance_type=config.instance_type,
                            spot=not args.nospot,
                            skip_setup=args.skip_setup)

    job.rsync('.')
    job.run(f'killall python || echo failed && '  # kill previous run
            f'source activate {config.conda_env} && ' +
            f'pip install -r requirements.txt')

    local_batch_size = config.local_batch_size
    base_lr = config.base_lr

    num_workers = num_gpus_per_machine * config.machines
    global_batch_size = local_batch_size * num_workers
    print("using global batch ", global_batch_size)  # 512=8*32*2*1

    # linear LR scaling (https://arxiv.org/abs/1706.02677)
    lr = base_lr * (global_batch_size / BASE_LR_BATCHSIZE)

    # worker parameters with training setup
    worker_params = {
        'seed': 1111,
        'data': 'data/wikitext-103',
        'dataset': 'wt103',
        'adaptive': True,
        'log_interval': 100,
        'eval_interval': 500,
        'max_tokens': int(1.5e9),
        'logdir': job.logdir,
        'lr': lr,
        'batch_size': local_batch_size,
        'eta_min': lr / 10,
    }
    
    worker_params.update(LARGE_ARGS if config.large else SMALL_ARGS)

    user_params = {}
    # pass through some user-provided settings that were arguments to the launcher script
    if args.checkpoint_each_epoch:
        user_params['checkpoint_each_epoch'] = args.checkpoint_each_epoch
    if config.warmup_tokens:
        user_params['warmup_tokens'] = config.warmup_tokens

    if args.checkpoint or config.checkpoint:
        user_params['checkpoint'] = util.one_of([args.checkpoint, config.checkpoint])

    if args.wiki:
        worker_params.update({
            'data': 'data/wikiextracted',
            'dataset': 'wiki',
            'dropatt': 0.1,
            'dropout': 0.1,
        })

    if args.bpe:
        worker_params.update({
            'div_val': 1,
            'bpe': True,
            'adaptive': False,
        })

    worker_params.update(user_params)

    if config.extra_worker_params:
        worker_params.update(config.extra_worker_params)

    nccl_params = _get_nccl_params()

    for i, task in enumerate(job.tasks):
        dist_params = \
            f'--nproc_per_node={num_gpus_per_machine} ' \
            f'--nnodes={config.machines} --node_rank={i} ' \
            f'--master_addr={job.tasks[0].ip} --master_port={6016}'
        cmd = f'{nccl_params} python -m torch.distributed.launch {dist_params} train.py {dict_to_args(worker_params)}'
        task.run(f'echo {cmd} > {job.logdir}/task-{i}.cmd')  # save command-line
        task.run(cmd, non_blocking=True)

    print(f"Logging to {job.logdir}")
Example #23
0
def evaluation(model, data_loader, writer, config, iter_idx,
               N_eval=None, N_seg_metrics=50):

    model.eval()
    torch.set_grad_enabled(False)

    batch_size = data_loader.batch_size

    if iter_idx == 0 or config.debug:
        num_batches = 1
        fprint(f"ITER 0 / DEBUG - eval on {num_batches} batches", True)
    elif N_eval is not None and N_eval <= len(data_loader)*batch_size:
        num_batches = int(N_eval // batch_size)
        fprint(f"N_eval = {N_eval}, eval on {num_batches} batches", True)
    else:
        num_batches = len(data_loader)
        fprint(f"Eval on all {num_batches} batches")

    start_t = time.time()
    eval_stats = AttrDefault(list, {})
    batch = None

    # Loop over loader
    for b_idx, batch in enumerate(data_loader):
        if b_idx == num_batches:
            fprint(f"Breaking from eval loop after {b_idx} batches")
            break

        if config.gpu:
            for key, val in batch.items():
                batch[key] = val.cuda()

        # Forward pass
        _, losses, stats, _, _ = model(batch['input'])

        # Track individual loss terms
        for key, val in losses.items():
            # Sum over steps if needed
            if isinstance(val, list):
                eval_stats[key].append(torch.stack(val, 1).sum(1).mean(0))
            else:
                eval_stats[key].append(val.mean(0))

        # Track ELBO
        kl_m, kl_l = torch.tensor(0), torch.tensor(0)
        if 'kl_m_k' in losses:
            kl_m = torch.stack(losses.kl_m_k, dim=1).sum(1).mean(0)
        elif 'kl_m' in losses:
            kl_m = losses.kl_m.mean(0)
        if 'kl_l_k' in losses:
            kl_l = torch.stack(losses.kl_l_k, dim=1).sum(1).mean(0)
        elif 'kl_l' in losses:
            kl_l = losses.kl_l.mean(0)
        eval_stats['elbo'].append(losses.err.mean(0) + kl_m + kl_l)

        # Track segmentation metrics
        if ('instances' in batch and 'log_m_k' in stats and
                b_idx * batch_size < N_seg_metrics):
            # ARI
            new_ari, _ = average_ari(
                stats.log_m_k, batch['instances'])
            new_ari_fg, _ = average_ari(
                stats.log_m_k, batch['instances'], True)
            eval_stats['ari'].append(new_ari)
            eval_stats['ari_fg'].append(new_ari_fg)
            # Segmentation Covering
            iseg = torch.argmax(torch.cat(stats.log_m_k, 1), 1, True)
            msc, _ = average_segcover(batch['instances'], iseg)
            msc_fg, _ = average_segcover(batch['instances'], iseg,
                                         ignore_background=True)
            eval_stats['msc'].append(msc)
            eval_stats['msc_fg'].append(msc_fg)

    # Average over batches
    for key, val in eval_stats.items():
        # Sanity check
        if ('ari' in key or 'msc' in key) and not config.debug and iter_idx > 0:
            assert len(val)*batch_size >= N_seg_metrics
            assert len(val)*batch_size < N_seg_metrics+batch_size
        eval_stats[key] = sum(val) / len(val)

    # Track element-wise error
    nelements = np.prod(batch['input'].shape[1:4])
    eval_stats['err_element'] = eval_stats['err'] / nelements

    # Printing
    duration = time.time() - start_t
    fprint(f'Eval duration: {duration:.1f}s, {num_batches / duration:.1f} b/s')
    eval_stats['duration'] = duration
    eval_stats['num_batches'] = num_batches
    eval_stats = dict(eval_stats)
    for key, val in eval_stats.items():
        eval_stats[key] = float(val)

    # TensorBoard logging
    if writer is not None:
        log_scalars(eval_stats, 'val', iter_idx, writer)

    model.train()
    torch.set_grad_enabled(True)

    return eval_stats
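
A minimal sketch of the accumulate-then-average pattern used by evaluation() above (assuming the attrdict-style AttrDefault; keys and values are illustrative only):

import torch
from attrdict import AttrDefault

eval_stats = AttrDefault(list, {})              # missing keys default to an empty list
for step in range(3):                           # stand-in for the batch loop
    eval_stats['err'].append(torch.tensor(0.5))
    eval_stats['kl'].append(torch.tensor(0.1 * step))
for key, val in eval_stats.items():
    eval_stats[key] = sum(val) / len(val)       # average the per-batch values
eval_stats = dict(eval_stats)                   # plain dict, as done before returning above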
Example #24
0
def main():
    attempt_count = 0
    while os.path.exists(f"{args.logdir}{attempt_count:02d}"):
        attempt_count += 1
    logdir = f"{args.logdir}{attempt_count:02d}"

    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")
    u.seed_random(1)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method

    except Exception as e:
        print(f"wandb crash with {e}")

    #    data_width = 4
    #    targets_width = 2

    d1 = args.data_width**2
    d2 = 10
    d3 = args.targets_width**2
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    model = u.SimpleFullyConnected(d, nonlin=args.nonlin)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    dataset = u.TinyMNIST(data_width=args.data_width,
                          targets_width=args.targets_width,
                          dataset_size=args.dataset_size)
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=False,
        drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.stats_batch_size,
        shuffle=False,
        drop_last=True)
    stats_iter = u.infinite_iter(stats_loader)

    def capture_activations(module, input, _output):
        if skip_forward_hooks:
            return
        assert gl.backward_idx == 0  # no need to forward-prop on Hessian computation
        assert not hasattr(
            module, 'activations'
        ), "Seeing activations from previous forward, call util.zero_grad to clear"
        assert len(input) == 1, "this works for single input layers only"
        setattr(module, "activations", input[0].detach())

    def capture_backprops(module: nn.Module, _input, output):
        if skip_backward_hooks:
            return
        assert len(output) == 1, "this works for single variable layers only"
        if gl.backward_idx == 0:
            assert not hasattr(
                module, 'backprops'
            ), "Seeing results of previous autograd, call util.zero_grad to clear"
            setattr(module, 'backprops', [])
        assert gl.backward_idx == len(module.backprops)
        module.backprops.append(output[0])

    def save_grad(param: nn.Parameter) -> Callable[[torch.Tensor], None]:
        """Hook to save gradient into 'param.saved_grad', so it can be accessed after model.zero_grad(). Only stores gradient
        if the value has not been set, call util.zero_grad to clear it."""
        def save_grad_fn(grad):
            if not hasattr(param, 'saved_grad'):
                setattr(param, 'saved_grad', grad)

        return save_grad_fn

    for layer in model.layers:
        layer.register_forward_hook(capture_activations)
        layer.register_backward_hook(capture_backprops)
        layer.weight.register_hook(save_grad(layer.weight))

    def loss_fn(data, targets):
        err = data - targets.view(-1, data.shape[1])
        assert len(data) == len(targets)
        return torch.sum(err * err) / 2 / len(data)

    gl.token_count = 0
    for step in range(args.stats_steps):
        data, targets = next(stats_iter)
        skip_forward_hooks = False
        skip_backward_hooks = False

        # get gradient values
        gl.backward_idx = 0
        u.zero_grad(model)
        output = model(data)
        loss = loss_fn(output, targets)
        loss.backward(retain_graph=True)

        print("loss", loss.item())

        # get Hessian values
        skip_forward_hooks = True
        id_mat = torch.eye(o)

        u.log_scalars({'loss': loss.item()})

        # o = 0
        for out_idx in range(o):
            model.zero_grad()
            # backprop to get section of batch output jacobian for output at position out_idx
            output = model(
                data
            )  # opt: using autograd.grad means I don't have to zero_grad
            ei = id_mat[out_idx]
            bval = torch.stack([ei] * n)
            gl.backward_idx = out_idx + 1
            output.backward(bval)
        skip_backward_hooks = True  #

        for (i, layer) in enumerate(model.layers):
            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops[0] * n
            assert B_t.shape == (n, d[i + 1])

            G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]),
                              layer.weight.saved_grad)

            # empirical Fisher
            efisher = G.t() @ G / n
            sigma = efisher - g.t() @ g
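            # sigma is the empirical per-example gradient covariance: average of g_i^T g_i minus g^T g (row-vector convention)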
            # u.dump(sigma, f'/tmp/sigmas/{step}-{i}')
            s.sigma_l2 = u.l2_norm(sigma)

            #############################
            # Hessian stats
            #############################
            A_t = layer.activations
            Bh_t = [layer.backprops[out_idx + 1] for out_idx in range(o)]
            Amat_t = torch.cat([A_t] * o, dim=0)
            Bmat_t = torch.cat(Bh_t, dim=0)

            assert Amat_t.shape == (n * o, d[i])
            assert Bmat_t.shape == (n * o, d[i + 1])

            Jb = u.khatri_rao_t(Bmat_t,
                                Amat_t)  # batch Jacobian, in row-vec format
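            # layer Hessian estimate assembled from the per-output backprops: H = Jb^T Jb / n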
            H = Jb.t() @ Jb / n
            pinvH = u.pinv(H)

            s.hess_l2 = u.l2_norm(H)
            s.invhess_l2 = u.l2_norm(pinvH)

            s.hess_fro = H.flatten().norm()
            s.invhess_fro = pinvH.flatten().norm()

            s.jacobian_l2 = u.l2_norm(Jb)
            s.grad_fro = g.flatten().norm()
            s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1],
                                                d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) -
                                          0.5 * eps**2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() /
                                          dd.flatten().norm()**2)

            s.regret_newton = u.to_python_scalar(g @ u.pinv(H) @ g.t() / 2)
            s.grad_curv = curv_direction(g)
            ndir = g @ u.pinv(H)  # newton direction
            s.newton_curv = curv_direction(ndir)
            setattr(layer.weight, 'pre',
                    u.pinv(H))  # save Newton preconditioner
            s.step_openai = 1 / s.grad_curv if s.grad_curv else 999

            s.newton_fro = ndir.flatten().norm(
            )  # frobenius norm of Newton update
            s.regret_gradient = loss_direction(g, s.step_openai)

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        for i in range(args.train_steps):
            optimizer.zero_grad()
            data, targets = next(train_iter)
            model.zero_grad()
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward()

            u.log_scalar(train_loss=loss.item())

            if args.method != 'newton':
                optimizer.step()
            else:
                for (layer_idx, layer) in enumerate(model.layers):
                    param: torch.nn.Parameter = layer.weight
                    param_data: torch.Tensor = param.data
                    param_data.copy_(param_data - 0.1 * param.grad)
                    if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                        continue
                    u.nan_check(layer.weight.pre)
                    u.nan_check(param.grad.flatten())
                    u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                    param_new_flat = u.v2r(param_data.flatten()) - u.v2r(
                        param.grad.flatten()) @ layer.weight.pre
                    u.nan_check(param_new_flat)
                    param_data.copy_(param_new_flat.reshape(param_data.shape))

            gl.token_count += data.shape[0]

    gl.event_writer.close()
Example #25
0
def least_squares(data, targets=None):
    if targets is None:
        targets = torch.zeros_like(data)
    err = data - targets.view(-1, data.shape[1])
    return torch.sum(err * err) / 2 / len(data)

d=1
n=1
model = simple_model(1, 5)
data = torch.ones((n, d))
targets = torch.ones((n, d))
loss_fn = least_squares

autograd_lib.register(model)

hess = defaultdict(float)
hess_diag = defaultdict(float)
hess_kfac = defaultdict(lambda: AttrDefault(float))

activations = {}
def save_activations(layer, A, _):
    activations[layer] = A

    # KFAC left factor
    hess_kfac[layer].AA += torch.einsum("ni,nj->ij", A, A)

with autograd_lib.module_hook(save_activations):
    output = model(data)
    loss = loss_fn(output, targets)

def compute_hess(layer, _, B):
    A = activations[layer]
    BA = torch.einsum("nl,ni->nli", B, A)
Example #26
0
    def fit(self, epochs, start_epoch=0):

        try:
            for epoch in range(start_epoch, epochs):
                # Training
                if self.prune_mode:
                    self.model.set_prune_flag(True)
                    ramp_up_function = None
                    if self.config.adaptive_prune_rampup_mode == "linear":
                        ramp_up_function = linear_rampup
                    if self.config.adaptive_prune_rampup_mode == "exponential":
                        ramp_up_function = customsigmoid_rampup
                    remaining_params = self.model.update_prune_weights(
                        self.real_prune_percentage * ramp_up_function(
                            epoch, self.config.adaptive_prune_rampup_len),
                        self.config.prune_mode)
                    self.writer.add_scalar("remaining_params",
                                           remaining_params, epoch)
                    print("remaining_params (epoch %d): %d" %
                          (epoch, remaining_params))

                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.training:
                        if dataset_config.frequency and (
                            (epoch + 1) % dataset_config.frequency):
                            continue
                        self.train(epoch, name, dataset_config)
                    # notify the model that training is done
                    epoch_done_op = getattr(self.bare_model, "epoch_done",
                                            None)
                    if callable(epoch_done_op):
                        epoch_done_op(epoch)
                if self.use_swa and (epoch + 1) >= self.use_swa and (
                        epoch + 1 - self.use_swa) % self.swa_c_epochs == 0:
                    swa_moving_average(self.swa_model, self.bare_model,
                                       1.0 / (self.swa_n + 1))
                    self.swa_n += 1
                    if not self.config["swa_no_bn_update"]:
                        bn_update(self.data_loaders['training'],
                                  self.swa_model)
                    self.state['swa_state_dict'] = self.swa_model.state_dict()
                    self.state['swa_n'] = self.swa_n
                    #self.run.info['swa_n'] = self.swa_n
                    self.save_model(epoch)
                    # Testing
                    swa_testing_result = {}
                    for name in self.config.datasets:
                        dataset_config = AttrDefault(
                            lambda: None, self.config.datasets[name])
                        if dataset_config.testing:
                            swa_testing_result[name] = self.test(
                                epoch,
                                name,
                                dataset_config,
                                model=self.swa_model,
                                extra_name="_swa")

                # Testing
                testing_result = {}
                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.testing:
                        testing_result[name] = self.test(
                            epoch, name, dataset_config)

                # updating the state with new results
                self.update_state(testing_result, epoch)

                #self.run.info['epoch'] = epoch
                self.eventAfterEpoch(self, epoch)

                if shared_globals.current_learning_rate < self.min_lr:
                    shared_globals.console.info(
                        "learning rate reached minimum {} ({}), stopping in epoch {}"
                        .format(self.min_lr,
                                shared_globals.current_learning_rate, epoch))
                    break

        except KeyboardInterrupt:
            pass
        shared_globals.console.info("last test:\n" +
                                    str(self.state['metrics']))
        if self.prune_mode:
            shared_globals.console.info(
                "trained model parameters non-zero (before update_prune) {} ".
                format(count_non_zero_params(self.bare_model)))
            self.prepare_prune()
            shared_globals.console.info(
                "trained model parameters non-zero: {} ".format(
                    count_non_zero_params(self.bare_model)))
Example #27
0
def main():

    u.seed_random(1)
    logdir = u.create_local_logdir(args.logdir)
    run_name = os.path.basename(logdir)
    gl.event_writer = SummaryWriter(logdir)
    print(f"Logging to {run_name}")

    d1 = args.data_width ** 2
    assert args.data_width == args.targets_width
    o = d1
    n = args.stats_batch_size
    d = [d1, 30, 30, 30, 20, 30, 30, 30, d1]

    # small values for debugging
    # loss_type = 'LeastSquares'
    loss_type = 'CrossEntropy'

    args.wandb = 0
    args.stats_steps = 10
    args.train_steps = 10
    args.stats_batch_size = 10
    args.data_width = 2
    args.targets_width = 2
    args.nonlin = False
    d1 = args.data_width ** 2
    d2 = 2
    d3 = args.targets_width ** 2

    if loss_type == 'CrossEntropy':
        d3 = 10
    o = d3
    n = args.stats_batch_size
    d = [d1, d2, d3]
    dsize = max(args.train_batch_size, args.stats_batch_size)+1

    model = u.SimpleFullyConnected2(d, bias=True, nonlin=args.nonlin)
    model = model.to(gl.device)

    try:
        # os.environ['WANDB_SILENT'] = 'true'
        if args.wandb:
            wandb.init(project='curv_train_tiny', name=run_name)
            wandb.tensorboard.patch(tensorboardX=False)
            wandb.config['train_batch'] = args.train_batch_size
            wandb.config['stats_batch'] = args.stats_batch_size
            wandb.config['method'] = args.method
            wandb.config['n'] = n
    except Exception as e:
        print(f"wandb crash with {e}")

    #optimizer = torch.optim.SGD(model.parameters(), lr=0.03, momentum=0.9)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.03)  # make 10x smaller for least-squares loss
    dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, dataset_size=dsize, original_targets=True)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    train_iter = u.infinite_iter(train_loader)

    stats_iter = None
    if not args.full_batch:
        stats_loader = torch.utils.data.DataLoader(dataset, batch_size=args.stats_batch_size, shuffle=False, drop_last=True)
        stats_iter = u.infinite_iter(stats_loader)

    test_dataset = u.TinyMNIST(data_width=args.data_width, targets_width=args.targets_width, train=False, dataset_size=dsize, original_targets=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.train_batch_size, shuffle=False, drop_last=True)
    test_iter = u.infinite_iter(test_loader)

    if loss_type == 'LeastSquares':
        loss_fn = u.least_squares
    elif loss_type == 'CrossEntropy':
        loss_fn = nn.CrossEntropyLoss()

    autograd_lib.add_hooks(model)
    gl.token_count = 0
    last_outer = 0
    val_losses = []
    for step in range(args.stats_steps):
        if last_outer:
            u.log_scalars({"time/outer": 1000*(time.perf_counter() - last_outer)})
        last_outer = time.perf_counter()

        with u.timeit("val_loss"):
            test_data, test_targets = next(test_iter)
            test_output = model(test_data)
            val_loss = loss_fn(test_output, test_targets)
            print("val_loss", val_loss.item())
            val_losses.append(val_loss.item())
            u.log_scalar(val_loss=val_loss.item())

        # compute stats
        if args.full_batch:
            data, targets = dataset.data, dataset.targets
        else:
            data, targets = next(stats_iter)

        # Capture Hessian and gradient stats
        autograd_lib.enable_hooks()
        autograd_lib.clear_backprops(model)
        autograd_lib.clear_hess_backprops(model)
        with u.timeit("backprop_g"):
            output = model(data)
            loss = loss_fn(output, targets)
            loss.backward(retain_graph=True)
        with u.timeit("backprop_H"):
            autograd_lib.backprop_hess(output, hess_type=loss_type)
        autograd_lib.disable_hooks()   # TODO(y): use remove_hooks

        with u.timeit("compute_grad1"):
            autograd_lib.compute_grad1(model)
        with u.timeit("compute_hess"):
            autograd_lib.compute_hess(model)

        for (i, layer) in enumerate(model.layers):

            # input/output layers are unreasonably expensive if not using Kronecker factoring
            if d[i]>50 or d[i+1]>50:
                print(f'layer {i} is too big ({d[i],d[i+1]}), skipping stats')
                continue

            if args.skip_stats:
                continue

            s = AttrDefault(str, {})  # dictionary-like object for layer stats

            #############################
            # Gradient stats
            #############################
            A_t = layer.activations
            assert A_t.shape == (n, d[i])

            # add factor of n because backprop takes loss averaged over batch, while we need per-example loss
            B_t = layer.backprops_list[0] * n
            assert B_t.shape == (n, d[i + 1])

            with u.timeit(f"khatri_g-{i}"):
                G = u.khatri_rao_t(B_t, A_t)  # batch loss Jacobian
            assert G.shape == (n, d[i] * d[i + 1])
            g = G.sum(dim=0, keepdim=True) / n  # average gradient
            assert g.shape == (1, d[i] * d[i + 1])

            u.check_equal(G.reshape(layer.weight.grad1.shape), layer.weight.grad1)

            if args.autograd_check:
                u.check_close(B_t.t() @ A_t / n, layer.weight.saved_grad)
                u.check_close(g.reshape(d[i + 1], d[i]), layer.weight.saved_grad)

            s.sparsity = torch.sum(layer.output <= 0) / layer.output.numel()  # proportion of activations that are zero
            s.mean_activation = torch.mean(A_t)
            s.mean_backprop = torch.mean(B_t)

            # empirical Fisher
            with u.timeit(f'sigma-{i}'):
                efisher = G.t() @ G / n
                sigma = efisher - g.t() @ g
                s.sigma_l2 = u.sym_l2_norm(sigma)
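                # effective rank of the noise covariance, measured here as trace / spectral norm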
                s.sigma_erank = torch.trace(sigma)/s.sigma_l2

            lambda_regularizer = args.lmb * torch.eye(d[i + 1]*d[i]).to(gl.device)
            H = layer.weight.hess

            with u.timeit(f"invH-{i}"):
                invH = torch.cholesky_inverse(H+lambda_regularizer)

            with u.timeit(f"H_l2-{i}"):
                s.H_l2 = u.sym_l2_norm(H)
                s.iH_l2 = u.sym_l2_norm(invH)

            with u.timeit(f"norms-{i}"):
                s.H_fro = H.flatten().norm()
                s.iH_fro = invH.flatten().norm()
                s.grad_fro = g.flatten().norm()
                s.param_fro = layer.weight.data.flatten().norm()

            u.nan_check(H)
            if args.autograd_check:
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                H_autograd = u.hessian(loss, layer.weight)
                H_autograd = H_autograd.reshape(d[i] * d[i + 1], d[i] * d[i + 1])
                u.check_close(H, H_autograd)

            #  u.dump(sigma, f'/tmp/sigmas/H-{step}-{i}')
            def loss_direction(dd: torch.Tensor, eps):
                """loss improvement if we take step eps in direction dd"""
                return u.to_python_scalar(eps * (dd @ g.t()) - 0.5 * eps ** 2 * dd @ H @ dd.t())

            def curv_direction(dd: torch.Tensor):
                """Curvature in direction dd"""
                return u.to_python_scalar(dd @ H @ dd.t() / (dd.flatten().norm() ** 2))

            with u.timeit(f"pinvH-{i}"):
                pinvH = u.pinv(H)

            with u.timeit(f'curv-{i}'):
                s.grad_curv = curv_direction(g)
                ndir = g @ pinvH  # newton direction
                s.newton_curv = curv_direction(ndir)
                setattr(layer.weight, 'pre', pinvH)  # save Newton preconditioner
                s.step_openai = s.grad_fro**2 / s.grad_curv if s.grad_curv else 999
                s.step_max = 2 / s.H_l2
                s.step_min = torch.tensor(2) / torch.trace(H)

                s.newton_fro = ndir.flatten().norm()  # frobenius norm of Newton update
                s.regret_newton = u.to_python_scalar(g @ pinvH @ g.t() / 2)   # replace with "quadratic_form"
                s.regret_gradient = loss_direction(g, s.step_openai)

            with u.timeit(f'rho-{i}'):
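                # rho compares the stationary noise covariance P (solution of the Lyapunov
                # equation H P + P H = sigma) to an isotropic one: rho = dim(H) / erank(P)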
                p_sigma = u.lyapunov_svd(H, sigma)
                if u.has_nan(p_sigma) and args.compute_rho:  # use expensive method
                    print('using expensive method')
                    H0, sigma0 = u.to_numpys(H, sigma)
                    p_sigma = scipy.linalg.solve_lyapunov(H0, sigma0)
                    p_sigma = torch.tensor(p_sigma).to(gl.device)

                if u.has_nan(p_sigma):
                    # import pdb; pdb.set_trace()
                    s.psigma_erank = H.shape[0]
                    s.rho = 1
                else:
                    s.psigma_erank = u.sym_erank(p_sigma)
                    s.rho = H.shape[0] / s.psigma_erank

            with u.timeit(f"batch-{i}"):
                s.batch_openai = torch.trace(H @ sigma) / (g @ H @ g.t())
                s.diversity = torch.norm(G, "fro") ** 2 / torch.norm(g) ** 2 / n

                # Faster approaches for noise variance computation
                # s.noise_variance = torch.trace(H.inverse() @ sigma)
                # try:
                #     # this fails with singular sigma
                #     s.noise_variance = torch.trace(torch.solve(sigma, H)[0])
                #     # s.noise_variance = torch.trace(torch.lstsq(sigma, H)[0])
                #     pass
                # except RuntimeError as _:
                s.noise_variance_pinv = torch.trace(pinvH @ sigma)

                s.H_erank = torch.trace(H) / s.H_l2
                s.batch_jain_simple = 1 + s.H_erank
                s.batch_jain_full = 1 + s.rho * s.H_erank

            u.log_scalars(u.nest_stats(layer.name, s))

        # gradient steps
        with u.timeit('inner'):
            for i in range(args.train_steps):
                optimizer.zero_grad()
                data, targets = next(train_iter)
                model.zero_grad()
                output = model(data)
                loss = loss_fn(output, targets)
                loss.backward()

                #            u.log_scalar(train_loss=loss.item())

                if args.method != 'newton':
                    optimizer.step()
                    if args.weight_decay:
                        for group in optimizer.param_groups:
                            for param in group['params']:
                                param.data.mul_(1-args.weight_decay)
                else:
                    for (layer_idx, layer) in enumerate(model.layers):
                        param: torch.nn.Parameter = layer.weight
                        param_data: torch.Tensor = param.data
                        param_data.copy_(param_data - 0.1 * param.grad)
                        if layer_idx != 1:  # only update 1 layer with Newton, unstable otherwise
                            continue
                        u.nan_check(layer.weight.pre)
                        u.nan_check(param.grad.flatten())
                        u.nan_check(u.v2r(param.grad.flatten()) @ layer.weight.pre)
                        param_new_flat = u.v2r(param_data.flatten()) - u.v2r(param.grad.flatten()) @ layer.weight.pre
                        u.nan_check(param_new_flat)
                        param_data.copy_(param_new_flat.reshape(param_data.shape))

                gl.token_count += data.shape[0]

    gl.event_writer.close()
Example #28
0
    def __init__(self, config, seed=42, mixed_precision_training=False):
        global logger
        logger = shared_globals.logger
        config = AttrDefault(lambda: None, config)

        self.config = config
        self.datasets = {}
        self.data_loaders = {}
        self.use_swa = config.use_swa
        self.prune_mode = config.get("prune_mode")
        #self.run.info['epoch'] = 0
        # set random seed
        torch.manual_seed(seed)
        np.random.seed(seed + 1)
        random.seed(seed + 2)

        self.min_lr = self.config.optim_config["min_lr"]
        if self.min_lr is None:
            self.min_lr = 0.0
        print(self.min_lr)
        # making output dirs
        models_outputdir = os.path.join(config.out_dir, "models")
        if not os.path.exists(config.out_dir):
            os.makedirs(config.out_dir)
        if not os.path.exists(models_outputdir):
            os.makedirs(models_outputdir)
        #self.run.info['out_path'] = config.out_dir
        self.colab_mode = False
        self.mixed_precision_training = mixed_precision_training
        if mixed_precision_training:
            print("\n\nUsing mixed_precision_training\n\n ")
            self.scaler = torch.cuda.amp.GradScaler()

        # init_loggers
        self.init_loggers()

        self.dataset_manager = DatasetsManager(self.config['audiodataset'])

        # init Tensor board
        if self.config.tensorboard:
            tensorboard_write_path = config.tensorboard_write_path
            if not tensorboard_write_path:
                tensorboard_write_path = self.config.out_dir.replace(
                    "out", "runs", 1)
            shared_globals.console.info("tensorboard run path: " +
                                        tensorboard_write_path)
            shared_globals.console.info("To monitor this experiment use:\n " +
                                        shared_globals.bcolors.FAIL +
                                        "tensorboard --logdir " +
                                        tensorboard_write_path +
                                        shared_globals.bcolors.ENDC)
            #self.run.info['tensorboard_path'] = tensorboard_write_path
            self.writer = SummaryWriter(tensorboard_write_path)

        # init multi gpu
        self.bare_model = load_model(config.model_config)
        print(self.bare_model)
        if self.use_swa:
            self.swa_model = load_model(config.model_config)
            if self.config.use_gpu:
                self.swa_model.cuda()
            self.swa_n = 0
            self.swa_c_epochs = config.swa_c_epochs
            self.swa_start = config.swa_start

        # print number of parameters
        print("Trainable model parameters {}, non-trainable {} ".format(
            count_parameters(self.bare_model),
            count_parameters(self.bare_model, False)))
        print("Trainable model parameters non-zero {} ".format(
            count_non_zero_params(self.bare_model)))

        # move to gpu
        if self.config.use_gpu:
            self.bare_model.cuda()

        if self.prune_mode:
            try:
                true_params = self.bare_model.get_num_true_params()
                prunable_params = self.bare_model.get_num_prunable_params()
                shared_globals.console.info(
                    "True model parameters {}, Prunable params {} ".format(
                        true_params, prunable_params))
            except AttributeError:
                raise
                true_params = prunable_params = count_parameters(
                    self.bare_model)
                shared_globals.console.info(
                    "WARNING:\n\nmodel doens't support true/prunable: True {}, Prunable params {} \n\n"
                    .format(true_params, prunable_params))
            if self.config.prune_percentage == -1:  # -1 means auto
                must_prune_params = true_params - self.config.prune_percentage_target_params
                self.real_prune_percentage = must_prune_params / prunable_params
                if self.real_prune_percentage >= 0.9999:
                    raise RuntimeError(
                        "real_prune_percentage {} >= ~ 1.".format(
                            self.real_prune_percentage))
                if self.real_prune_percentage >= 0.9:
                    print("\n\nWarning: very high real_prune_percentage\n\n",
                          self.real_prune_percentage)
                if self.real_prune_percentage < 0:
                    raise RuntimeError("real_prune_percentage {} <0.".format(
                        self.real_prune_percentage))
                    print("\nWARNING: real_prune_percentage<0: ",
                          self.real_prune_percentage, " setting to 0.1\n")
                    self.real_prune_percentage = 0.1
            else:
                self.real_prune_percentage = self.config.prune_percentage
            print("current prunning percentage=", self.real_prune_percentage)

        shared_globals.console.info(
            "\n\nTrainable model parameters {}, non-trainable {} \n\n".format(
                count_parameters(self.bare_model),
                count_parameters(self.bare_model, False)))
        # DataParallel mode
        if not config.parallel_mode:
            self.model = self.bare_model
        elif config.parallel_mode == "distributed":
            torch.distributed.init_process_group(
                backend='nccl',
                world_size=1,
                rank=0,
                init_method='file://' + config.out_dir + "/shared_file")
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.bare_model)
        else:
            self.model = torch.nn.DataParallel(self.bare_model)
        # self.model.cuda()

        # if load_model

        if config.get('load_model'):
            load_model_path = config.get('load_model')
            load_model_path = os.path.expanduser(load_model_path)
            shared_globals.console.info("Loading model located at: " +
                                        load_model_path)
            checkpoint = torch.load(load_model_path)
            self.model.load_state_dict(checkpoint['state_dict'])
            if self.use_swa:
                swa_state_dict = checkpoint.get('swa_state_dict', None)
                self.swa_n = checkpoint.get('swa_n', 0)
                if (swa_state_dict
                        is not None) and not self.config.swa_model_load_same:
                    self.swa_model.load_state_dict(swa_state_dict)
                else:
                    shared_globals.console.warning(
                        "No swa_state_dict found in checkpoint; loading the regular state_dict into the SWA model")
                    self.swa_model.load_state_dict(checkpoint['state_dict'])
                    self.swa_n = 0

        shared_globals.logger.info(str(self.model))
        shared_globals.current_learning_rate = config.optim_config['base_lr']
        self.optimizer, self.scheduler = create_optimizer(
            self.model.parameters(), config.optim_config)
        print("optimizer:", self.optimizer)
        loss_criterion_args = dict(config.loss_criterion_args)
        self.criterion = get_criterion(
            config.loss_criterion)(**loss_criterion_args)

        # init state
        inf_value = -float("inf")
        if self.config["optim_config"].get("model_selection",
                                           {}).get("select_min", False):
            inf_value = float("inf")
        self.state = {
            # 'config': self.config,
            'state_dict': None,
            'optimizer': None,
            'epoch': 0,
            'metrics': {},
            'best_metric_value': inf_value,
            'best_epoch': 0,
        }
        self.first_batch_done = False
        # init dataset loaders
        self.init_loaders()

        if config.get('load_model'):
            if not config.get("load_model_no_test_first"):
                testing_result = {}
                for name in self.config.datasets:
                    dataset_config = AttrDefault(lambda: None,
                                                 self.config.datasets[name])
                    if dataset_config.testing:
                        testing_result[name] = self.test(
                            0, name, dataset_config)

                # updating the state with new results
                self.update_state(testing_result, 0)
Example #29
0
def main():
    config = AttrDefault(lambda: None, config_defaults)

    assert args.config in globals(), f"unknown config {args.config}"
    config.update(eval(args.config))

    job = ncluster.make_job(name=args.name,
                            run_name=f"{args.name}",
                            num_tasks=config.machines,
                            image_name=config.image_name,
                            instance_type=config.instance_type,
                            spot=not args.nospot,
                            skip_setup=args.skip_setup)

    job.rsync('.')
    job.run(f'killall python || echo failed && '  # kill previous run
            f'source activate {config.conda_env} && ' +
            f'pip install -r requirements.txt')

    instance_info = ncluster.aws_backend.INSTANCE_INFO[config.instance_type]
    num_gpus_per_machine = instance_info['gpus']

    total_gpus = num_gpus_per_machine * config.machines
    global_batch_size = config.batch_size * total_gpus

    # linear LR scaling (https://arxiv.org/abs/1706.02677)
    lr = config.base_lr * (global_batch_size / BASE_LR_BATCHSIZE)

    # TODO(y): change dataset location to /data/transformer-xl-data after
    # image is cut
    # worker parameters with training setup
    worker_params = {
        'seed': 1111,
        'data': 'data/wikitext-103',
        'dataset': 'wt103',
        'adaptive': True,
        'log_interval': 100,
        'eval_interval': 1000,
        'logdir': job.logdir,
        'lr': lr,
        'fp16': True,
        'dynamic_loss_scale': True,
        'batch_size': config.batch_size,
    }

    if config.architecture == 'wt103_large':
        worker_params.update(wt103_large)
    elif config.architecture == 'wt103_base':
        worker_params.update(wt103_base)
    else:
        assert False, f"Uknown architecture {config.architecture}"

    nccl_params = f'NCCL_DEBUG=VERSION NCCL_MIN_NRINGS={config.num_rings} '

    for i, task in enumerate(job.tasks):
        dist_params = \
            f'--nproc_per_node={num_gpus_per_machine} ' \
            f'--nnodes={config.machines} --node_rank={i} ' \
            f'--master_addr={job.tasks[0].ip} --master_port={6016} '
        cmd = f'{nccl_params} python -m torch.distributed.launch {dist_params} ' \
            f'train.py {util.dict_to_args(worker_params)}'
        task.run(f'echo {cmd} > {job.logdir}/task-{i}.cmd')  # save command-line
        task.run(cmd, non_blocking=True)

    print(f"Logging to {job.logdir}")

    if args.launch_tensorboard:
        task = ncluster.make_task('tensorboard',
                                  instance_type='r5.large',
                                  image_name=args.image_name)

        task.run('source activate tensorflow_p36')
        task.run(f'tensorboard --logdir={ncluster.get_logdir_root()} --port=6006',
                 non_blocking=True)
        print(f'TensorBoard at http://{task.public_ip}:6006')
Example #30
0
def test_grad_norms():
    """Test computing gradient norms using various methods."""

    u.seed_random(1)
    # torch.set_default_dtype(torch.float64)

    data_width = 3
    batch_size = 2
    d = [data_width**2, 6, 10]
    o = d[-1]
    stats_steps = 2
    num_samples = batch_size * stats_steps  # number of samples used in computation of curvature stats

    model: u.SimpleModel = u.SimpleMLP(d, nonlin=True, bias=True)
    loss_fn = torch.nn.CrossEntropyLoss()
    autograd_lib.register(model)

    dataset = u.TinyMNIST(dataset_size=num_samples,
                          data_width=data_width,
                          original_targets=True)
    stats_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=batch_size,
                                               shuffle=False)
    stats_iter = iter(stats_loader)

    moments = defaultdict(lambda: AttrDefault(float))
    norms = defaultdict(lambda: AttrDefault(MyList))
    data_batches = []
    targets_batches = []
    for stats_step in range(stats_steps):
        data, targets = next(stats_iter)
        data_batches.append(data)
        targets_batches.append(targets)

        activations = {}

        def forward_aggregate(layer, A, _):
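            # Forward hook: stash this layer's input activations and accumulate
            # their batch sum (a) and uncentered second moment (AA).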
            activations[layer] = A
            moments[layer].AA += torch.einsum('ni,nj->ij', A, A)
            moments[layer].a += torch.einsum("ni->i", A)

        with autograd_lib.module_hook(forward_aggregate):
            output = model(data)
            loss_fn(output, targets)

        def backward_aggregate(layer, _, B):
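            # Backward hook: combine the saved activations A with the
            # backpropagated values B and accumulate the moments used by the
            # gradient-norm approximations: b (batch sum of B), BA (summed
            # per-example gradients), BB (second moment of B), and BABA
            # (summed outer products of per-example gradients with themselves).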
            A = activations[layer]
            moments[layer].b += torch.einsum("nk->k", B)
            moments[layer].BA += torch.einsum("nl,ni->li", B, A)
            moments[layer].BB += torch.einsum("nk,nl->kl", B, B)
            moments[layer].BABA += torch.einsum('nl,ni,nk,nj->likj', B, A, B,
                                                A)

        with autograd_lib.module_hook(backward_aggregate):
            autograd_lib.backward_hessian(output,
                                          loss='CrossEntropy',
                                          retain_graph=True)

    # compare against results using autograd
    data = torch.cat(data_batches)
    targets = torch.cat(targets_batches)

    with autograd_lib.save_activations2() as activations:
        loss = loss_fn(model(data), targets)

    def normalize_moments(d, n):
        result = AttrDict()
        for val in d:
            if isinstance(d[val], torch.Tensor):
                result[val] = d[val] / n
        return result

    def compute_norms(layer, _, B):
        A = activations[layer]
        for kind in ('zero_order', 'kfac', 'isserlis', 'full'):
            normalized_moments = normalize_moments(moments[layer], num_samples)
            norms_list = getattr(norms[layer], kind)
            norms_list.extend(
                autograd_lib.grad_norms(A, B, normalized_moments, approx=kind))

    with autograd_lib.module_hook(compute_norms):
        model.zero_grad()
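        # CrossEntropyLoss averages over the batch, so scaling by len(data)
        # backpropagates the summed loss and yields per-example backprop values.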
        (len(data) * loss).backward(retain_graph=True)

        print(norms[model.layers[0]].zero_order.value())

    for layer in model.layers:
        output = model(data)
        losses = torch.stack([
            loss_fn(output[i:i + 1], targets[i:i + 1])
            for i in range(len(data))
        ])
        grads = u.jacobian(losses, layer.weight)
        grad_norms = torch.einsum('nij,nij->n', grads, grads)
        u.check_close(grad_norms, norms[layer].zero_order)

        # test gradient norms with custom metric
        kfac_norms, isserlis_norms, full_norms = [
            u.to_pytorch(getattr(norms[layer], k))
            for k in ('kfac', 'isserlis', 'full')
        ]
        error_kfac = max(abs(kfac_norms - full_norms))
        error_isserlis = max(abs(isserlis_norms - full_norms))
        assert error_isserlis < 1e-4
        assert error_kfac < 1e-4
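
The comparison above checks cheap approximations against exact per-example gradient norms from u.jacobian. For a linear layer the per-example weight gradient is the outer product of the backpropagated error b with the activation a, so its squared Frobenius norm factorizes as ||b||^2 * ||a||^2, which is presumably what the cheaper approximations exploit. A tiny standalone check of that identity (hypothetical tensors, independent of autograd_lib):

import torch

a = torch.randn(6)                            # per-example activation (layer input)
b = torch.randn(10)                           # per-example backprop (grad w.r.t. layer output)
grad = torch.outer(b, a)                      # per-example weight gradient, shape (10, 6)
direct = torch.einsum('ij,ij->', grad, grad)  # squared Frobenius norm of the gradient
factored = (a @ a) * (b @ b)                  # ||a||^2 * ||b||^2
assert torch.allclose(direct, factored)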