Example #1

import copy
import glob
import math
import os
import pathlib

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Project-local helpers are assumed importable from the surrounding repo:
# load_data, train, validate, and the measures module.


def train_model(args):
    use_cuda = not args["no_cuda"] and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    nchannels, nclasses = 3, 10
    if args["dataset"] == 'MNIST': nchannels = 1
    if args["dataset"] == 'CIFAR100': nclasses = 100

    ir_strength = args["init_reg_strength"]
    square_loss = args["square_loss"]

    # hook factory that records layer outputs (its name argument is unused)
    activations = []

    def get_activation(name):
        def hook(model, input, output):
            activations.append(output.detach())

        return hook

    # create an initial model
    model = nn.Sequential(nn.Linear(32 * 32 * nchannels, args["nunits"]),
                          nn.ReLU(), nn.Linear(args["nunits"], nclasses))
    model = model.to(device)

    # create a copy of the initial model to be used later
    init_model = copy.deepcopy(model)

    # register a forward hook on the first linear layer so its pre-activation
    # outputs are appended to `activations` on every forward pass
    model[0].register_forward_hook(get_activation(model[1]))

    # define optimizer
    optimizer = optim.SGD(model.parameters(),
                          args["learningrate"],
                          momentum=args["momentum"],
                          weight_decay=args["weightdecay"])

    # loading data
    train_dataset = load_data('train', args["dataset"], args["datadir"])
    val_dataset = load_data('val', args["dataset"], args["datadir"])

    train_loader = DataLoader(train_dataset,
                              batch_size=args["batchsize"],
                              shuffle=True,
                              **kwargs)
    val_loader = DataLoader(val_dataset,
                            batch_size=args["batchsize"],
                            shuffle=False,
                            **kwargs)

    start_epoch = 0

    path = "saved_models/" + args["dataset"] + "/SQUARE/WD" + str(
        args["weightdecay"]) + "/N" + str(int(math.log(args["nunits"], 2)))
    if os.path.isdir(path):
        # If exact epochs dir exists, select it
        # Else find latest directory
        epoch_path = path + "/E" + str(args["epochs"])
        if os.path.isdir(epoch_path):
            latest_checkpoint = epoch_path + "/checkpoint.pth.tar"
        else:
            latest_dir = max(glob.glob(os.path.join(path, '*/')),
                             key=os.path.getmtime)
            latest_checkpoint = os.path.join(latest_dir, "checkpoint.pth.tar")

        checkpoint = torch.load(latest_checkpoint, map_location=device)
        start_epoch = checkpoint['epoch']
        epoch = start_epoch
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['state_dict'])
        init_model.load_state_dict(checkpoint['init'])
        print("Loading checkpoint for model: " +
              str(int(math.log(args['nunits'], 2))) + " epoch " + str(epoch))

    # training the model
    for epoch in range(start_epoch, args["epochs"]):
        # train for one epoch
        tr_err, tr_loss = train(model, init_model, device, train_loader,
                                optimizer, ir_strength, square_loss)

        val_err, val_loss, val_margin = validate(model, init_model, device,
                                                 val_loader, ir_strength,
                                                 square_loss)

        print(
            'Epoch: ' + str(epoch + 1) + "/" + str(args["epochs"]) +
            '\t Training loss: ' + str(round(tr_loss, 3)) + '\t',
            'Training error: ' + str(round(tr_err, 3)) +
            '\t Validation error: ' + str(round(val_err, 3)))

        # checkpoint every 50 epochs
        if (epoch + 1) % 50 == 0:
            path = "./saved_models/" + args["dataset"] + "/SQUARE/WD" + str(
                args["weightdecay"]) + "/N" + str(
                    int(math.log(args["nunits"], 2))) + "/E" + str(epoch + 1)
            pathlib.Path(path).mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    "state_dict": model.state_dict(),
                    "init": init_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": (epoch + 1)
                }, path + "/checkpoint.pth.tar")

        # stop training once the training loss drops below the stopping condition
        if tr_loss < args["stopcond"]:
            break

    tr_err, tr_loss, tr_margin = validate(model, init_model, device,
                                          train_loader, ir_strength,
                                          square_loss)
    val_err, val_loss, val_margin = validate(model, init_model, device,
                                             val_loader, ir_strength,
                                             square_loss)
    print('\nFinal: Training loss: ' + str(round(tr_loss, 3)) +
          '\t Training margin: ' + str(round(tr_margin, 3)) +
          '\t Training error: ' + str(round(tr_err, 3)) +
          '\t Validation error: ' + str(round(val_err, 3)) + '\n')

    measure = measures.calculate(model, init_model, device, train_loader,
                                 tr_margin)
    return measure, activations
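
A minimal call sketch for train_model (the keys mirror those read inside the function and the values mirror the defaults in Example #2 where they exist; the init_reg_strength and square_loss values are illustrative assumptions):

args = {
    "no_cuda": False,
    "dataset": "CIFAR10",
    "datadir": "datasets",
    "nunits": 1024,
    "epochs": 1000,
    "stopcond": 0.01,
    "batchsize": 64,
    "learningrate": 0.001,
    "momentum": 0.9,
    "weightdecay": 0.001,
    "init_reg_strength": 0.0,  # assumed placeholder value
    "square_loss": True,       # assumed: train with the square loss
}
measure, activations = train_model(args)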
Example #2

import argparse
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# As in Example #1, load_data, train, validate, and measures are assumed to
# come from the surrounding project.


def main():

    # settings
    parser = argparse.ArgumentParser(
        description='Training a fully connected NN with one hidden layer')
    parser.add_argument('--no-cuda',
                        default=False,
                        action='store_true',
                        help='disables CUDA training')
    parser.add_argument(
        '--datadir',
        default='datasets',
        type=str,
        help=
        'path to the directory that contains the datasets (default: datasets)')
    parser.add_argument(
        '--dataset',
        default='CIFAR10',
        type=str,
        help=
        'name of the dataset (options: MNIST | CIFAR10 | CIFAR100 | SVHN, default: CIFAR10)'
    )
    parser.add_argument('--nunits',
                        default=1024,
                        type=int,
                        help='number of hidden units (default: 1024)')
    parser.add_argument('--epochs',
                        default=1000,
                        type=int,
                        help='number of epochs to train (default: 1000)')
    parser.add_argument(
        '--stopcond',
        default=0.01,
        type=float,
        help='stopping condition based on the cross-entropy loss (default: 0.01)'
    )
    parser.add_argument('--batchsize',
                        default=64,
                        type=int,
                        help='input batch size (default: 64)')
    parser.add_argument('--learningrate',
                        default=0.001,
                        type=float,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum',
                        default=0.9,
                        type=float,
                        help='momentum (default: 0.9)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    nchannels, nclasses = 3, 10
    if args.dataset == 'MNIST': nchannels = 1
    if args.dataset == 'CIFAR100': nclasses = 100

    # create an initial model
    model = nn.Sequential(nn.Linear(32 * 32 * nchannels, args.nunits),
                          nn.ReLU(), nn.Linear(args.nunits, nclasses))
    model = model.to(device)

    # create a copy of the initial model to be used later
    init_model = copy.deepcopy(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(),
                          args.learningrate,
                          momentum=args.momentum)

    # loading data
    train_dataset = load_data('train', args.dataset, args.datadir, nchannels)
    val_dataset = load_data('val', args.dataset, args.datadir, nchannels)

    train_loader = DataLoader(train_dataset,
                              batch_size=args.batchsize,
                              shuffle=True,
                              **kwargs)
    val_loader = DataLoader(val_dataset,
                            batch_size=args.batchsize,
                            shuffle=False,
                            **kwargs)

    # training the model
    for epoch in range(0, args.epochs):
        # train for one epoch
        tr_err, tr_loss = train(args, model, device, train_loader, criterion,
                                optimizer, epoch)

        val_err, val_loss, val_margin = validate(args, model, device,
                                                 val_loader, criterion)

        print(
            f'Epoch: {epoch + 1}/{args.epochs}\t Training loss: {tr_loss:.3f}\t',
            f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}')

        # stop training if the cross-entropy loss is less than the stopping condition
        if tr_loss < args.stopcond: break

    # calculate the training error and margin of the learned model
    tr_err, tr_loss, tr_margin = validate(args, model, device, train_loader,
                                          criterion)
    print(
        f'\nFinal: Training loss: {tr_loss:.3f}\t Training margin: {tr_margin:.3f}\t',
        f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}\n')

    measure = measures.calculate(model, init_model, device, train_loader,
                                 tr_margin)
    for key, value in measure.items():
        print(f'{key:s}:\t {float(value):3.3}')
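
Run as a script, this example implies the usual entry-point guard (not shown in the excerpt):

if __name__ == '__main__':
    main()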
Example #3

import copy

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import main      # project module assumed to provide validate()
import measures  # project module assumed to provide calculate()

# Reconstructed scaffold (assumed from context): the fragment below clearly
# runs inside a loop over checkpoint epochs; train_loader / val_loader are
# assumed to be built beforehand as in Example #2.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = list(range(100, 600, 100))  # assumed checkpoint epochs (cf. xticks)
bounds = [[] for _ in range(6)]
for epoch in epochs:
    # CIFAR10 input (3 channels), 2^14 hidden units, matching the path below
    model = nn.Sequential(nn.Linear(32 * 32 * 3, 2 ** 14), nn.ReLU(),
                          nn.Linear(2 ** 14, 10))
    model = model.to(device)
    init_model = copy.deepcopy(model)
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9, weight_decay=0.001)

    checkpoint_path = "saved_models/CIFAR10/WD0.0025/N14/E" + str(epoch) + "/checkpoint.pth.tar"

    print("Loading checkpoint for model: 2^14 at epoch " + str(epoch))

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    init_model.load_state_dict(checkpoint['init'])
    
    tr_err, tr_loss, tr_margin = main.validate(model, init_model, device, train_loader)
    val_err, val_loss, val_margin = main.validate(model, init_model, device, val_loader)
    measure = measures.calculate(model, init_model, device, train_loader, tr_margin)
    # the last six entries of `measure` are the capacity bounds; collect them
    bound = [float(v) for _, v in list(measure.items())[-6:]]
    for i in range(6):
        bounds[i].append(bound[i])

plt.plot(epochs, np.array(bounds[0]), marker="+", label="(1) VC-dim", color="blue")
plt.plot(epochs, np.array(bounds[1]), marker="+", label="(2) l1,max", color="orange")
plt.plot(epochs, np.array(bounds[2]), marker="+", label="(3) Fro", color="green")
plt.plot(epochs, np.array(bounds[3]), marker="+", label="(4) spec-l2,1", color="black")
plt.plot(epochs, np.array(bounds[4]), marker="+", label="(5) spec-Fro", color="brown")
plt.plot(epochs, np.array(bounds[5]), marker="+", label="(6) ours", color="red")
plt.xlabel("Epoch #")
plt.ylabel("Capacity")
plt.xticks(range(0, 600, 100))
plt.yscale("log")
plt.legend()
plt.show()