import argparse
import copy
import glob
import math
import os
import pathlib

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Project-local helpers assumed to be in scope: load_data, train, validate,
# and the measures module (used as measures.calculate).


def train_model(args):
    use_cuda = not args["no_cuda"] and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    nchannels, nclasses = 3, 10
    if args["dataset"] == 'MNIST':
        nchannels = 1
    if args["dataset"] == 'CIFAR100':
        nclasses = 100

    ir_strength = args["init_reg_strength"]
    square_loss = args["square_loss"]

    # make a hook that stores the hidden layer's outputs on every forward pass
    activations = []

    def get_activation(name):
        def hook(model, input, output):
            activations.append(output.detach())
        return hook

    # create an initial model
    model = nn.Sequential(nn.Linear(32 * 32 * nchannels, args["nunits"]),
                          nn.ReLU(),
                          nn.Linear(args["nunits"], nclasses))
    model = model.to(device)

    # create a copy of the initial model to be used later
    init_model = copy.deepcopy(model)

    # register the hook on the hidden layer (get_activation expects a string
    # name, not a module, so pass a label rather than model[1])
    model[0].register_forward_hook(get_activation('hidden'))

    # define optimizer
    optimizer = optim.SGD(model.parameters(), args["learningrate"],
                          momentum=args["momentum"],
                          weight_decay=args["weightdecay"])

    # loading data
    train_dataset = load_data('train', args["dataset"], args["datadir"])
    val_dataset = load_data('val', args["dataset"], args["datadir"])
    train_loader = DataLoader(train_dataset, batch_size=args["batchsize"],
                              shuffle=True, **kwargs)
    val_loader = DataLoader(val_dataset, batch_size=args["batchsize"],
                            shuffle=False, **kwargs)

    # resume from a checkpoint if one exists: prefer the directory for the
    # requested number of epochs, otherwise fall back to the most recently
    # modified checkpoint directory
    start_epoch = 0
    path = "saved_models/" + args["dataset"] + "/SQUARE/WD" + str(
        args["weightdecay"]) + "/N" + str(int(math.log(args["nunits"], 2)))
    if os.path.isdir(path):
        epoch_path = path + "/E" + str(args["epochs"])
        if os.path.isdir(epoch_path):
            latest_checkpoint = os.path.join(epoch_path, "checkpoint.pth.tar")
        else:
            latest_dir = max(glob.glob(os.path.join(path, '*/')),
                             key=os.path.getmtime)
            latest_checkpoint = os.path.join(latest_dir, "checkpoint.pth.tar")
        checkpoint = torch.load(latest_checkpoint)
        start_epoch = checkpoint['epoch']
        optimizer.load_state_dict(checkpoint['optimizer'])
        model.load_state_dict(checkpoint['state_dict'])
        init_model.load_state_dict(checkpoint['init'])
        print("Loading checkpoint for model: " +
              str(int(math.log(args['nunits'], 2))) +
              " epoch " + str(start_epoch))

    # training the model
    for epoch in range(start_epoch, args["epochs"]):
        # train for one epoch
        tr_err, tr_loss = train(model, init_model, device, train_loader,
                                optimizer, ir_strength, square_loss)
        val_err, val_loss, val_margin = validate(model, init_model, device,
                                                 val_loader, ir_strength,
                                                 square_loss)
        print('Epoch: ' + str(epoch + 1) + "/" + str(args["epochs"]) +
              '\t Training loss: ' + str(round(tr_loss, 3)) + '\t',
              'Training error: ' + str(round(tr_err, 3)) +
              '\t Validation error: ' + str(round(val_err, 3)))

        # save a checkpoint every 50 epochs
        if (epoch + 1) % 50 == 0 and epoch > 0:
            path = "./saved_models/" + args["dataset"] + "/SQUARE/WD" + str(
                args["weightdecay"]) + "/N" + str(
                    int(math.log(args["nunits"], 2))) + "/E" + str(epoch + 1)
            pathlib.Path(path).mkdir(parents=True, exist_ok=True)
            torch.save(
                {
                    "state_dict": model.state_dict(),
                    "init": init_model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "epoch": epoch + 1
                }, path + "/checkpoint.pth.tar")

        # stop training once the training loss falls below the stopping condition
        if tr_loss < args["stopcond"]:
            break

    # final evaluation: training error/margin and validation error
    tr_err, tr_loss, tr_margin = validate(model, init_model, device,
                                          train_loader, ir_strength,
                                          square_loss)
    val_err, val_loss, val_margin = validate(model, init_model, device,
                                             val_loader, ir_strength,
                                             square_loss)
    print('\nFinal: Training loss: ' + str(round(tr_loss, 3)) +
          '\t Training margin: ' + str(round(tr_margin, 3)) +
          '\t Training error: ' + str(round(tr_err, 3)) +
          '\t Validation error: ' + str(round(val_err, 3)) + '\n')

    measure = measures.calculate(model, init_model, device, train_loader,
                                 tr_margin)
    return measure, activations
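
# A minimal usage sketch for train_model (not part of the original script):
# the dict keys below are exactly the ones train_model reads above; the
# values are arbitrary example settings, and example_run is a hypothetical
# helper name.
def example_run():
    example_args = {
        "no_cuda": False,
        "dataset": "CIFAR10",
        "datadir": "datasets",
        "nunits": 1024,
        "epochs": 1000,
        "stopcond": 0.01,
        "batchsize": 64,
        "learningrate": 0.001,
        "momentum": 0.9,
        "weightdecay": 0.0005,      # example value
        "init_reg_strength": 0.1,   # example value
        "square_loss": True,        # presumably selects the square loss
                                    # (the checkpoint path contains "SQUARE")
    }
    measure, activations = train_model(example_args)
    for key, value in measure.items():
        print(key, float(value))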
def main():
    # settings
    parser = argparse.ArgumentParser(
        description='Training a fully connected NN with one hidden layer')
    parser.add_argument('--no-cuda', default=False, action='store_true',
                        help='disables CUDA training')
    parser.add_argument('--datadir', default='datasets', type=str,
                        help='path to the directory that contains the '
                             'datasets (default: datasets)')
    parser.add_argument('--dataset', default='CIFAR10', type=str,
                        help='name of the dataset (options: MNIST | CIFAR10 '
                             '| CIFAR100 | SVHN, default: CIFAR10)')
    parser.add_argument('--nunits', default=1024, type=int,
                        help='number of hidden units (default: 1024)')
    parser.add_argument('--epochs', default=1000, type=int,
                        help='number of epochs to train (default: 1000)')
    parser.add_argument('--stopcond', default=0.01, type=float,
                        help='stopping condition based on the cross-entropy '
                             'loss (default: 0.01)')
    parser.add_argument('--batchsize', default=64, type=int,
                        help='input batch size (default: 64)')
    parser.add_argument('--learningrate', default=0.001, type=float,
                        help='learning rate (default: 0.001)')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='momentum (default: 0.9)')
    args = parser.parse_args()

    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    nchannels, nclasses = 3, 10
    if args.dataset == 'MNIST':
        nchannels = 1
    if args.dataset == 'CIFAR100':
        nclasses = 100

    # create an initial model
    model = nn.Sequential(nn.Linear(32 * 32 * nchannels, args.nunits),
                          nn.ReLU(),
                          nn.Linear(args.nunits, nclasses))
    model = model.to(device)

    # create a copy of the initial model to be used later
    init_model = copy.deepcopy(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), args.learningrate,
                          momentum=args.momentum)

    # loading data
    train_dataset = load_data('train', args.dataset, args.datadir, nchannels)
    val_dataset = load_data('val', args.dataset, args.datadir, nchannels)
    train_loader = DataLoader(train_dataset, batch_size=args.batchsize,
                              shuffle=True, **kwargs)
    val_loader = DataLoader(val_dataset, batch_size=args.batchsize,
                            shuffle=False, **kwargs)

    # training the model
    for epoch in range(0, args.epochs):
        # train for one epoch
        tr_err, tr_loss = train(args, model, device, train_loader, criterion,
                                optimizer, epoch)
        val_err, val_loss, val_margin = validate(args, model, device,
                                                 val_loader, criterion)
        print(f'Epoch: {epoch + 1}/{args.epochs}\t Training loss: {tr_loss:.3f}\t',
              f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}')

        # stop training if the cross-entropy loss is less than the stopping condition
        if tr_loss < args.stopcond:
            break

    # calculate the training error and margin of the learned model
    tr_err, tr_loss, tr_margin = validate(args, model, device, train_loader,
                                          criterion)
    print(f'\nFinal: Training loss: {tr_loss:.3f}\t Training margin: {tr_margin:.3f}\t',
          f'Training error: {tr_err:.3f}\t Validation error: {val_err:.3f}\n')

    measure = measures.calculate(model, init_model, device, train_loader,
                                 tr_margin)
    for key, value in measure.items():
        print(f'{key:s}:\t {float(value):3.3}')
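
# Standard entry point (an addition: the original snippet defines main() but
# never calls it).
if __name__ == '__main__':
    main()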

# --- analysis script: track capacity bounds across saved checkpoints ---
# This fragment assumes the setup from main(): `device`, `train_loader`, and
# `val_loader` are defined and a fresh `model` has been built. It also assumes
# the training module above is importable as `main`. The checkpoint cadence
# (every 50 epochs, up to 550) is inferred from the saving schedule in
# train_model and the x-ticks below.
import matplotlib.pyplot as plt
import numpy as np

import main

epochs = list(range(50, 600, 50))
bounds = [[] for _ in range(6)]  # one series per capacity measure

for epoch in epochs:
    model = model.to(device)
    init_model = copy.deepcopy(model)
    optimizer = optim.SGD(model.parameters(), 0.001, momentum=0.9,
                          weight_decay=0.001)

    # load the checkpoint of the 2^14-unit network at this epoch
    checkpoint_path = ("saved_models/CIFAR10/WD0.0025/N14/E" + str(epoch) +
                       "/checkpoint.pth.tar")
    print("Loading checkpoint for model: 2^14 at epoch " + str(epoch))
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    init_model.load_state_dict(checkpoint['init'])

    # recompute margins and the capacity measures for this checkpoint
    tr_err, tr_loss, tr_margin = main.validate(model, init_model, device,
                                               train_loader)
    val_err, val_loss, val_margin = main.validate(model, init_model, device,
                                                  val_loader)
    measure = measures.calculate(model, init_model, device, train_loader,
                                 tr_margin)

    # the last six entries of `measure` are the capacity bounds to plot
    bound = [float(v) for _, v in list(measure.items())[-6:]]
    for i in range(6):
        bounds[i].append(bound[i])

# plot each bound against the training epoch on a log scale
plt.plot(epochs, np.array(bounds[0]), marker="+", label="(1) VC-dim", color="blue")
plt.plot(epochs, np.array(bounds[1]), marker="+", label="(2) l1,max", color="orange")
plt.plot(epochs, np.array(bounds[2]), marker="+", label="(3) Fro", color="green")
plt.plot(epochs, np.array(bounds[3]), marker="+", label="(4) spec-l2,1", color="black")
plt.plot(epochs, np.array(bounds[4]), marker="+", label="(5) spec-Fro", color="brown")
plt.plot(epochs, np.array(bounds[5]), marker="+", label="(6) ours", color="red")
plt.xlabel("Epoch #")
plt.ylabel("Capacity")
plt.xticks(list(range(0, 600, 100)))
plt.yscale("log")
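
# Finish the figure (an addition: labels are set above but the legend is never
# drawn and the plot is never rendered; the output filename is illustrative).
plt.legend()
plt.savefig("capacity_bounds.png", bbox_inches="tight")
plt.show()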