def init_classifier(dataset):
    """
    Initialize a classifier based on the dataset.

    NOTE(review): this definition is shadowed by a later, identical-named
    `init_classifier` in this file; the original body here only handled
    'mnist' and fell off the end (returning None) for every other dataset.
    Fixed to cover all datasets the same way the later definition does.

    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :return: an un-trained classifier network matching the dataset shape.
    """
    d = data.to_dataset(dataset)
    if dataset == 'mnist':
        return lenet.LeNet5()
    elif dataset in ('svhn', 'fashion'):
        return resnet.ResNet(d.nc, d.ny)
    else:
        # Non-image datasets fall back to a simple MLP over flat features.
        return linear.MLP(d.nx, d.ny)
def init_generator(dataset):
    """
    Build the generator network appropriate for a dataset.

    Image datasets get a convolutional generator; everything else gets a
    dense (fully-connected) generator over flat features.

    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :return: an un-trained generator network.
    """
    meta = data.to_dataset(dataset)
    image_datasets = ('mnist', 'fashion', 'svhn')
    if dataset in image_datasets:
        return models.ImageGenerator(meta.ny, meta.nc)
    return models.DenseGenerator(meta.ny, meta.nx, n_layers=2)
def init_classifier(dataset):
    """
    Build the classifier network appropriate for a dataset.

    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :return: an un-trained classifier network matching the dataset shape.
    """
    meta = data.to_dataset(dataset)
    if dataset == 'mnist':
        network = lenet.LeNet5()
    elif dataset in ('svhn', 'fashion'):
        network = resnet.ResNet(meta.nc, meta.ny)
    else:
        # Non-image datasets fall back to a simple MLP over flat features.
        network = linear.MLP(meta.nx, meta.ny)
    return network
def sample_random_data(dataset, num_data, dist, device):
    """
    Draw artificial input samples from a simple noise distribution.

    :param dataset: dataset name; determines the per-sample tensor shape.
    :param num_data: number of samples to draw.
    :param dist: 'normal' (standard Gaussian) or 'uniform' (on [-1, 1]).
    :param device: torch device on which to allocate the tensor.
    :return: a float tensor of shape (num_data, *dataset_size).
    :raises ValueError: if `dist` is not a recognized distribution name.
    """
    shape = (num_data, *data.to_dataset(dataset).size)
    if dist == 'normal':
        return torch.randn(shape, device=device)
    if dist == 'uniform':
        samples = torch.zeros(shape, dtype=torch.float, device=device)
        samples.uniform_(-1, 1)  # in-place fill with U(-1, 1)
        return samples
    raise ValueError(dist)
def prepare_teacher(dataset):
    """
    Prepare datasets and hyperparameters for training a teacher network.

    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :return: a dict of hyperparameters, data loaders, and loss functions
        consumed by the training loop.
    """
    batch_size = 64
    # Per-dataset schedule:
    #   (lrn_rate, save_every, min_epochs, val_epochs, max_epochs)
    schedules = {
        'mnist': (1e-5, 10, 10000, 10000, 200),
        'fashion': (1e-2, 20, 10000, 10000, 100),
        'svhn': (1e-4, 10, 10000, 10000, 50),
    }
    fallback = (1e-3, 0, 100, 20, 1000)
    lrn_rate, save_every, min_epochs, val_epochs, max_epochs = \
        schedules.get(dataset, fallback)

    trn_data, val_data, test_data = \
        data.to_dataset(dataset).to_loaders(batch_size)
    # Teachers train and evaluate against the same cross-entropy objective.
    loss_func = nn.CrossEntropyLoss().to(DEVICE)
    return dict(
        lrn_rate=lrn_rate,
        save_every=save_every,
        min_epochs=min_epochs,
        val_epochs=val_epochs,
        max_epochs=max_epochs,
        trn_data=trn_data,
        val_data=val_data,
        test_data=test_data,
        trn_loss_func=loss_func,
        test_loss_func=loss_func)
def prepare_student(model, dataset, data_dist, generators=None):
    """
    Prepare datasets and hyperparameters for training a student network.

    :param model: the (trained) teacher network that labels the student's
        training data.
    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :param data_dist: source of student training inputs — 'kegnet'
        (generator samples) or 'normal'/'uniform' (random noise).
    :param generators: generator networks used when data_dist == 'kegnet'.
    :return: a dict of hyperparameters, data loaders, and loss functions
        consumed by the training loop.
    :raises ValueError: if `data_dist` is not a recognized value.
    """
    batch_size = 64
    num_batches = 100
    save_every = -1
    val_epochs = 50
    min_epochs = 100
    if data_dist == 'kegnet':
        max_epochs = 400
        lrn_rate = 1e-5 if dataset == 'mnist' else 1e-4
    elif data_dist in ('normal', 'uniform'):
        max_epochs = 1000
        lrn_rate = 1e-6 if dataset == 'mnist' else 1e-4
    else:
        # Fixed: include the offending value in the error, consistent with
        # sample_random_data's `raise ValueError(dist)`.
        raise ValueError(data_dist)
    # Student trains on teacher-labeled (artificial) data ...
    trn_data = prepare_data(
        model, data_dist, dataset, batch_size, num_batches, generators)
    # ... but is validated/tested against the real dataset splits.
    _, val_data, test_data = data.to_dataset(dataset).to_loaders(batch_size)
    # KL divergence matches the teacher's soft outputs during training;
    # cross-entropy against true labels measures final accuracy.
    trn_loss_func = cls_loss.KLDivLoss().to(DEVICE)
    test_loss_func = nn.CrossEntropyLoss().to(DEVICE)
    return dict(
        lrn_rate=lrn_rate,
        save_every=save_every,
        min_epochs=min_epochs,
        val_epochs=val_epochs,
        max_epochs=max_epochs,
        trn_data=trn_data,
        val_data=val_data,
        test_data=test_data,
        trn_loss_func=trn_loss_func,
        test_loss_func=test_loss_func)
def main(dataset, cls_path, out_path, index=0):
    """
    Main function for training a generator.

    Trains a generator (with an auxiliary decoder) against a frozen,
    pre-trained classifier, logging per-epoch losses to
    `<out_path>/loss-gen.txt`, periodically visualizing samples, and
    checkpointing the generator.

    :param dataset: dataset name, e.g. 'mnist', 'fashion', 'svhn'.
    :param cls_path: path to the trained classifier checkpoint.
    :param out_path: output directory for logs, images, and checkpoints.
    :param index: run index; also perturbs the random seed.
    :return: path of the last saved generator checkpoint (None if none
        was saved).
    """
    global DEVICE
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    utils.set_seed(seed=2019 + index)

    num_epochs = 200
    save_every = 100
    viz_every = 10
    # Sanity check on the constants above: otherwise no epoch would ever
    # satisfy `epoch % save_every == 0` and nothing would be checkpointed.
    assert num_epochs >= save_every

    dec_layers, lrn_rate, alpha, beta = _generator_hyperparams(dataset)

    # Networks: frozen classifier, the generator being trained, and a
    # decoder that reconstructs noise from generated samples.
    cls_network = cls_utils.init_classifier(dataset).to(DEVICE)
    gen_network = gen_utils.init_generator(dataset).to(DEVICE)
    utils.load_checkpoints(cls_network, cls_path, DEVICE)
    nz = gen_network.num_noises
    nx = data.to_dataset(dataset).nx
    dec_network = models.Decoder(nx, nz, dec_layers).to(DEVICE)
    networks = (gen_network, cls_network, dec_network)

    path_loss = os.path.join(out_path, 'loss-gen.txt')
    dir_model = os.path.join(out_path, 'generator')
    path_model = None
    os.makedirs(os.path.join(out_path, 'images'), exist_ok=True)
    with open(path_loss, 'w') as f:
        f.write('Epoch\tClsLoss\tDecLoss\tDivLoss\tLossSum\tAccuracy\n')

    # Objectives: match the classifier's output distribution (kld),
    # reconstruct the input noise (l2), and encourage sample diversity (l1).
    loss1 = gen_loss.ReconstructionLoss(method='kld').to(DEVICE)
    loss2 = gen_loss.ReconstructionLoss(method='l2').to(DEVICE)
    loss3 = gen_loss.DiversityLoss(metric='l1').to(DEVICE)
    losses = loss1, loss2, loss3

    # Generator and decoder are optimized jointly.
    params = list(gen_network.parameters()) + list(dec_network.parameters())
    optimizer = optim.Adam(params, lrn_rate)

    for epoch in range(1, num_epochs + 1):
        trn_acc, trn_losses = update(networks, losses, optimizer, alpha, beta)
        # Re-open in append mode each epoch so the log survives a crash.
        with open(path_loss, 'a') as f:
            f.write(f'{epoch:3d}')
            for loss in trn_losses:
                f.write(f'\t{loss:.8f}')
            f.write(f'\t{trn_acc:.8f}\n')
        if viz_every > 0 and epoch % viz_every == 0:
            path = os.path.join(out_path, 'images', f'images-{epoch:03d}.png')
            gen_utils.visualize_images(gen_network, path, DEVICE)
        if epoch % save_every == 0:
            path = f'{dir_model}-{epoch:03d}.pth.tar'
            utils.save_checkpoints(gen_network, path)
            path_model = path
    print(f'Finished training the generator (index={index}).')
    return path_model


def _generator_hyperparams(dataset):
    """
    Select per-dataset generator-training hyperparameters.

    :param dataset: dataset name.
    :return: tuple (dec_layers, lrn_rate, alpha, beta) where alpha/beta
        weight the decoder and diversity losses respectively.
    """
    if dataset == 'mnist':
        return 1, 1e-3, 1, 0
    if dataset == 'fashion':
        return 3, 1e-2, 1, 10
    if dataset == 'svhn':
        return 3, 1e-2, 1, 1
    return 2, 1e-4, 1, 0