# Two candidate per-layer rank configurations for the factorized VGG16
# (13 conv layers followed by 3 classifier layers; None keeps a layer
# full-rank). One configuration must be uncommented so that `rank` is
# defined before the model is built.
# rank = [None, None, (64, 64), None, (96, 96), (96, 96), None,
#         (128, 128), (128, 128), None, (128, 128), (128, 128), None,
#         #None, None, None, None, None, None, None, None, None, None, None, None, None,
#         None, None, None]
#         #((32, 32), (32, 32)), ((32, 32), (32, 32)), ((32, 32), (10,))]
# rank = [None, None, (128, 128), None, (256, 256), (256, 256), None,
#         (512, 512), (512, 512), None, (512, 512), (512, 512), None,
#         #None, None, None, None, None, None, None, None, None, None, None, None, None,
#         #None, None, None]
#         ((64, 64), (64, 64)), ((64, 64), (64, 64)), ((64, 64), (10,))]

model = vgg16_bn(rank).to(device)
with open("../models/cifar_tensor_reg.pth", 'rb') as f:
    model.load_state_dict(torch.load(f, map_location=device))
# model = models.vgg11_bn().to(device)

sampler = hmcsampler(model.parameters(),
                     samples_dir="../models/cifar_tensor_reg_samples")

epoch = 0
while len(sampler.samples) < 200:
    epoch += 1
    model.train()
    with tqdm(total=len(train_loader.dataset), desc='Iter {}'.format(epoch)) as bar:
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            sampler.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss += model.regularizer() / 1e4   # down-weighted low-rank prior term
            loss /= args.temperature            # tempered posterior
            loss.backward()
            sampler.step()
            bar.update(len(data))
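# For reference, a minimal sketch of how one (rank_in, rank_out) entry could
# factorize a conv layer, assuming vgg16_bn uses a Tucker-2-style decomposition
# (1x1 compress -> spatial conv -> 1x1 expand). The class name LowRankConv2d
# and its wiring are hypothetical, not necessarily what this repo's vgg16_bn does.
import torch.nn as nn

class LowRankConv2d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size, rank, padding=1):
        super().__init__()
        r_in, r_out = rank
        self.compress = nn.Conv2d(in_ch, r_in, 1)                         # project channels down
        self.core = nn.Conv2d(r_in, r_out, kernel_size, padding=padding)  # spatial conv at low rank
        self.expand = nn.Conv2d(r_out, out_ch, 1)                         # project channels back up

    def forward(self, x):
        return self.expand(self.core(self.compress(x)))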
# Threshold the CP weight vector lamb, keep only the components above it,
# and rebuild the model at the reduced rank.
ths = 0.05
ind = np.where(state_dict['lamb'].cpu().numpy() > ths)[0]
rank = len(ind)
state_dict['lamb'] = state_dict['lamb'][ind]
for i in range(3):
    state_dict['factors.{}'.format(i)] = state_dict['factors.{}'.format(i)][:, ind]
print('rank={}'.format(rank))

model = cp(size, rank)
# model.to(device)
model.load_state_dict(state_dict)

sampler = hmcsampler([{'params': model.factors, 'max_length': 0.001},          # 'max_length': 1e-2
                      {'params': model.lamb, 'mass': 1e2, 'max_length': 0.01}, # 'mass': 1e2, 'max_length': 1e-2
                      {'params': model.tau, 'mass': 1}],                       # 'mass': 1e2
                     frac_adj=1, max_step_size=0.01)

while len(sampler.samples) < 1000:
    loss_train = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        sampler.zero_grad()
        out = model(data)
        loss = criterion(out, target)
        loss_train += loss.item()
        # loss *= len(train_loader.dataset)
        # loss *= len(data)
        # loss *= torch.exp(model.tau)
        # loss -= 0.5 * len(train_loader.dataset) * model.tau
        loss = regulized_loss(loss, model, len(train_loader.dataset))
        loss.backward()
        sampler.step()
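# regulized_loss is defined elsewhere in the repo; a sketch of what it plausibly
# computes, pieced together from the commented-out lines above: rescale the mean
# batch loss to the full dataset, weight by the noise precision exp(tau), and
# subtract tau's normalizer term (tau acting as a log-precision). This body is
# an assumption, not the repo's actual definition.
import torch

def regulized_loss(loss, model, n_data):
    loss = loss * n_data * torch.exp(model.tau)  # full-data loss, precision-weighted
    loss = loss - 0.5 * n_data * model.tau       # log-normalizer term for the precision
    return loss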
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
# Note: the normalization constants below are the MNIST statistics carried over
# from the template this script is based on, not CIFAR-10's per-channel statistics.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=True, download=True,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.1307,), (0.3081,))
                     ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('../data', train=False,
                     transform=transforms.Compose([
                         transforms.ToTensor(),
                         transforms.Normalize((0.1307,), (0.3081,))
                     ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

# model = Net().to(device)
model = models.vgg11_bn().to(device)
# optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
optimizer = hmcsampler(model.parameters())

for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(args, model, device, test_loader)

if args.save_model:
    torch.save(model.state_dict(), "mnist_cnn.pth")  # filename also kept from the MNIST template
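# train and test are not shown in this snippet; a sketch of train in the style of
# the standard PyTorch MNIST example that this script follows (the signature and
# args.log_interval are assumptions; test would mirror it under torch.no_grad()).
import torch.nn.functional as F

def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{}]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset), loss.item()))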
    model.train()
    return x

# pbar = tqdm(total=args.num_samples)
# def pbar_update(correct):
#     total = len(test_loader.dataset)
#     pbar.set_postfix_str('Test set: Accuracy: {}/{} ({:.0f}%)'.format(
#         correct, total, 100. * correct / total), refresh=False)
#     pbar.update()

modelsaver = modelsaver_test(forwardfcn, test_loader)
sampler = hmcsampler(model.parameters(), sampler=modelsaver, max_length=1e-2)

while len(sampler.samples) < args.num_samples:
    # model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        sampler.zero_grad()
        output = model(data)
        loss = criterion(output, target) * len(train_loader.dataset)
        loss += model.regularizer()
        loss.backward()
        sampler.step()
    # test(model, test_loader)
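# modelsaver_test is not defined in this snippet; a guess at its shape, assuming
# hmcsampler invokes the callback with no arguments each time a sample is saved,
# and that it scores the current model on the test set via forwardfcn. Both the
# callback protocol and the class body are assumptions.
class modelsaver_test:
    def __init__(self, forwardfcn, test_loader):
        self.forwardfcn = forwardfcn
        self.test_loader = test_loader
        self.accuracies = []

    def __call__(self):
        correct = 0
        for data, target in self.test_loader:
            pred = self.forwardfcn(data).argmax(dim=1)  # forwardfcn returns cpu logits
            correct += pred.eq(target).sum().item()
        self.accuracies.append(correct / len(self.test_loader.dataset))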
# model = Net().to(device)
model = models.vgg16_bn(num_classes=10).to(device)
with open("../models/cifar_vgg16.pth", 'rb') as f:
    model.load_state_dict(torch.load(f, map_location=device))

def forwardfcn(x):
    # evaluate a batch without tracking gradients, then restore train mode
    model.eval()
    with torch.no_grad():
        x = x.to(device)
        x = model(x)
        x = x.cpu()
    model.train()
    return x

sampler = hmcsampler(model.parameters(),
                     sampler=modelsaver_test(forwardfcn, test_loader))

epoch = 0
while len(sampler.samples) < 200:
    epoch += 1
    model.train()
    bar = tqdm(total=len(train_loader.dataset), desc='Iter {}'.format(epoch))
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        sampler.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss *= len(train_loader.dataset) * 10  # rescale mean batch loss to a (sharpened) full-data loss
        loss.backward()
        sampler.step()
        bar.update(len(data))
    bar.close()
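# Once sampler.samples holds the 200 draws, the usual way to use them is
# posterior predictive averaging. A sketch, assuming each saved sample can be
# loaded back into the model as a state_dict (the storage format is an
# assumption; predict_posterior is a hypothetical helper, not repo code).
import torch
import torch.nn.functional as F

def predict_posterior(model, samples, data):
    model.eval()
    probs = None
    with torch.no_grad():
        for s in samples:
            model.load_state_dict(s)
            p = F.softmax(model(data.to(device)), dim=1).cpu()
            probs = p if probs is None else probs + p
    return probs / len(samples)  # averaged predictive distribution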