class Solver(object):
    """Train and evaluate an image classifier on CIFAR-10.

    ``config`` must expose ``lr``, ``epoch``, ``trainBatchSize``,
    ``testBatchSize`` and ``cuda``.  Data loaders, model, optimizer,
    scheduler and loss are created lazily by :meth:`load_data` and
    :meth:`load_model`; :meth:`run` drives the whole pipeline.
    """

    def __init__(self, config):
        self.model = None
        self.lr = config.lr
        self.epochs = config.epoch
        self.train_batch_size = config.trainBatchSize
        self.test_batch_size = config.testBatchSize
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        """Download CIFAR-10 into ``./data`` and build both data loaders."""
        # Only the training split is augmented (random horizontal flip).
        train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])

        train_set = torchvision.datasets.CIFAR10(root='./data',
                                                 train=True,
                                                 download=True,
                                                 transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=self.train_batch_size,
            shuffle=True)

        test_set = torchvision.datasets.CIFAR10(root='./data',
                                                train=False,
                                                download=True,
                                                transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=self.test_batch_size,
            shuffle=False)

    def load_model(self):
        """Select the device and create model, optimizer, scheduler, loss."""
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        # The original source listed commented-out alternatives here:
        # AlexNet, VGG11/13/16/19, GoogLeNet, resnet18/34/50/101/152,
        # DenseNet121/161/169/201, WideResNet(depth=28, num_classes=10).
        self.model = LeNet().to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        # Halve the learning rate once 75 and again once 150 epochs are done.
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=[75, 150],
                                                        gamma=0.5)
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def train(self):
        """Run one training epoch.

        Returns:
            ``(summed_batch_loss, accuracy)`` where accuracy is the fraction
            of correctly classified training samples (0.0 for an empty
            loader).
        """
        print("train:")
        self.model.train()
        train_loss = 0
        train_correct = 0
        total = 0

        for batch_num, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            # Index of the largest logit along the class dimension.
            _, predicted = torch.max(output, 1)
            total += target.size(0)
            # Count hits on-device; only one scalar crosses to the host.
            train_correct += (predicted == target).sum().item()

            progress_bar(
                batch_num, len(self.train_loader),
                'Loss: %.4f | Acc: %.3f%% (%d/%d)' %
                (train_loss / (batch_num + 1), 100. * train_correct / total,
                 train_correct, total))

        # max() guards the division when the loader yields no batches.
        return train_loss, train_correct / max(total, 1)

    def test(self):
        """Evaluate on the test loader without tracking gradients.

        Returns:
            ``(summed_batch_loss, accuracy)`` with the same conventions as
            :meth:`train`.
        """
        print("test:")
        self.model.eval()
        test_loss = 0
        test_correct = 0
        total = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.test_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                test_loss += loss.item()
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                test_correct += (predicted == target).sum().item()

                progress_bar(
                    batch_num, len(self.test_loader),
                    'Loss: %.4f | Acc: %.3f%% (%d/%d)' %
                    (test_loss / (batch_num + 1), 100. * test_correct / total,
                     test_correct, total))

        return test_loss, test_correct / max(total, 1)

    def save(self):
        """Serialize the whole model object to ``model.pth``."""
        model_out_path = "model.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        """Load data and model, train for ``self.epochs`` epochs, then save.

        The best test accuracy seen across epochs is reported after the
        final epoch; the model saved is the final one, not the best one.
        """
        self.load_data()
        print('Success loading data.')
        self.load_model()
        print('Success loading model.')

        accuracy = 0
        for epoch in range(1, self.epochs + 1):
            # Report the configured epoch count instead of a fixed "200".
            print("\n===> epoch: %d/%d" % (epoch, self.epochs))
            train_result = self.train()
            print(train_result)
            test_result = self.test()
            accuracy = max(accuracy, test_result[1])

            # Step the schedule once per epoch, after the optimizer updates
            # (passing the epoch number to step() is deprecated).
            self.scheduler.step()

            if epoch == self.epochs:
                print("===> BEST ACC. PERFORMANCE: %.3f%%" % (accuracy * 100))
                self.save()
def _append_epoch_result(path, epoch, train_acc, val_acc, test_acc):
    """Append one space-separated result row for *epoch* to the file *path*."""
    row = ' '.join(str(v) for v in (int(epoch), train_acc, val_acc, test_acc))
    with open(path, "a") as myfile:
        myfile.write(row + ' ' + "\n")


def _print_epoch_accuracies(epoch, n_epochs, train_acc, val_acc, test_acc):
    """Print train/val/test accuracy lines for one epoch of a stage."""
    print(
        'Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%' %
        (epoch + 1, n_epochs, len(train_dataset), train_acc))
    print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' %
          (epoch + 1, n_epochs, len(val_dataset), val_acc))
    print(
        'Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' %
        (epoch + 1, n_epochs, len(test_dataset), test_acc))


def _zero_small_entries(matrix, threshold=1e-6):
    """Set every entry of *matrix* below *threshold* to zero, in place."""
    matrix[matrix < threshold] = 0.


def main():
    """Run the full pipeline: warm-up training, matrix estimation, then
    loss-corrected training and a final revision stage.

    Reads its configuration from the module-level ``args`` and the
    module-level datasets, ``batch_size``, ``txtfile`` and
    ``model_save_dir``.  One result row per epoch is appended to ``txtfile``
    and the best-on-validation weights are written to
    ``model_save_dir/model.pth``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    checkpoint_path = model_save_dir + '/' + 'model.pth'

    # Data Loader (Input Pipeline)
    # All loaders keep the dataset order (shuffle=False); W / W_val below are
    # sliced by dataset position, so presumably the order must stay aligned.
    print('loading dataset...')
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=False,
                                               shuffle=False)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size,
                                             num_workers=args.num_workers,
                                             drop_last=False,
                                             shuffle=False)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=False,
                                              shuffle=False)

    # Define models
    print('building model...')
    if args.dataset == 'mnist':
        clf1 = LeNet()
    elif args.dataset == 'fashionmnist':
        clf1 = resnet.ResNet18_F(10)
    elif args.dataset in ('cifar10', 'svhn'):
        clf1 = resnet.ResNet34(10)
    else:
        # Previously this fell through and failed later on an unbound name.
        raise ValueError('unsupported dataset: %r' % (args.dataset,))
    clf1.cuda()

    optimizer = torch.optim.SGD(clf1.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)

    with open(txtfile, "a") as myfile:
        myfile.write('epoch train_acc val_acc test_acc\n')

    epoch = 0
    train_acc = 0
    val_acc = 0

    # evaluate models with random weights
    test_acc = evaluate(test_loader, clf1)
    print('Epoch [%d/%d] Test Accuracy on the %s test data: Model1 %.4f %%' %
          (epoch + 1, args.n_epoch_1, len(test_dataset), test_acc))
    # save results
    _append_epoch_result(txtfile, epoch, train_acc, val_acc, test_acc)

    # ---- Stage 1: plain cross-entropy training ---------------------------
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in range(1, args.n_epoch_1):
        clf1.train()
        train_acc = train(clf1, train_loader, epoch, optimizer, criterion)
        val_acc = evaluate(val_loader, clf1)
        test_acc = evaluate(test_loader, clf1)

        _print_epoch_accuracies(epoch, args.n_epoch_1, train_acc, val_acc,
                                test_acc)
        _append_epoch_result(txtfile, epoch, train_acc, val_acc, test_acc)

        # Keep only the weights that score best on the validation split.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), checkpoint_path)

    # ---- Matrix factorization of the extracted representations -----------
    print('Matrix Factorization is doing...')
    clf1.load_state_dict(torch.load(checkpoint_path))
    A = respresentations_extract(train_loader, clf1, len(train_dataset),
                                 args.dim, batch_size)
    A_val = respresentations_extract(val_loader, clf1, len(val_dataset),
                                     args.dim, batch_size)
    # Factorize train and validation rows jointly, then split W back apart.
    A_total = np.append(A, A_val, axis=0)
    W_total, H_total, error = train_m(A_total, args.basis,
                                      args.iteration_nmf, 1e-5)
    _zero_small_entries(W_total)
    W = W_total[0:len(train_dataset), :]
    W_val = W_total[len(train_dataset):, :]

    # ---- Transition-matrix estimation -------------------------------------
    print('Transition Matrix is estimating...Wating...')
    logits_matrix = probability_extract(train_loader, clf1,
                                        len(train_dataset), args.num_classes,
                                        batch_size)
    idx_matrix_group, transition_matrix_group = estimate_matrix(
        logits_matrix, model_save_dir)
    logits_matrix_val = probability_extract(val_loader, clf1,
                                            len(val_dataset),
                                            args.num_classes, batch_size)
    idx_matrix_group_val, transition_matrix_group_val = estimate_matrix(
        logits_matrix_val, model_save_dir)

    func = nn.MSELoss()
    model = Matrix_optimize(args.basis, args.num_classes)
    optimizer_1 = torch.optim.Adam(model.parameters(), lr=0.001)
    basis_matrix_group = basis_matrix_optimize(
        model, optimizer_1, args.basis, args.num_classes, W,
        transition_matrix_group, idx_matrix_group, func, model_save_dir,
        args.n_epoch_4)
    # The validation fit reuses the same model and optimizer instances.
    basis_matrix_group_val = basis_matrix_optimize(
        model, optimizer_1, args.basis, args.num_classes, W_val,
        transition_matrix_group_val, idx_matrix_group_val, func,
        model_save_dir, args.n_epoch_4)

    # NOTE(review): only the training basis is thresholded; the validation
    # basis is left as returned -- confirm this asymmetry is intended.
    _zero_small_entries(basis_matrix_group)

    # ---- Stage 2: training with loss correction ---------------------------
    optimizer_ = torch.optim.SGD(clf1.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 momentum=args.momentum)
    best_acc = 0.0
    for epoch in range(1, args.n_epoch_2):
        clf1.train()
        train_acc = train_correction(clf1, train_loader, epoch, optimizer_, W,
                                     basis_matrix_group, batch_size,
                                     args.num_classes, args.basis)
        val_acc = val_correction(clf1, val_loader, epoch, W_val,
                                 basis_matrix_group_val, batch_size,
                                 args.num_classes, args.basis)
        test_acc = evaluate(test_loader, clf1)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), checkpoint_path)

        _append_epoch_result(txtfile, epoch, train_acc, val_acc, test_acc)
        _print_epoch_accuracies(epoch, args.n_epoch_2, train_acc, val_acc,
                                test_acc)

    # ---- Stage 3: revision of the transition matrix -----------------------
    clf1.load_state_dict(torch.load(checkpoint_path))
    optimizer_r = torch.optim.Adam(clf1.parameters(),
                                   lr=args.lr_revision,
                                   weight_decay=args.weight_decay)
    # Start the revision term from zero.
    nn.init.constant_(clf1.T_revision.weight, 0.0)

    for epoch in range(1, args.n_epoch_3):
        clf1.train()
        train_acc = train_revision(clf1, train_loader, epoch, optimizer_r, W,
                                   basis_matrix_group, batch_size,
                                   args.num_classes, args.basis)
        # NOTE(review): validation here pairs W_val with the *training*
        # basis, whereas stage 2 uses basis_matrix_group_val -- confirm.
        val_acc = val_revision(clf1, val_loader, epoch, W_val,
                               basis_matrix_group, batch_size,
                               args.num_classes, args.basis)
        test_acc = evaluate(test_loader, clf1)

        _append_epoch_result(txtfile, epoch, train_acc, val_acc, test_acc)
        _print_epoch_accuracies(epoch, args.n_epoch_3, train_acc, val_acc,
                                test_acc)
# Starting value and upper bound for mu, both taken from the CLI arguments.
mu = args.min_mu
mu_max = args.max_mu

for epoch in range(1, args.epochs + 1):
    # Dropout is enabled only when a positive percentage is configured and
    # the current epoch is past the configured delay.
    dropout_active = args.LW_dropout_perc > 0 and args.LW_dropout_delay < epoch
    useDropout = args.LW_dropout_perc if dropout_active else 0

    # The batch size changes linearly with the epoch index.
    batch_size = args.batch_size + args.delta_batch_size * (epoch - 1)

    banner = '\nEpoch {} of {}. mu = {:.2f}, batch_size = {}, algorithm = {}'
    print(banner.format(epoch, args.epochs, mu, batch_size, algName))

    model.train()
    for batch_idx, (data, targets) in enumerate(train_loader):
        data, targets = data.to(device), targets.to(device)

        if algName == 'altmin':
            # Set L1 weights according to mu: a positive *_muFact scales the
            # weight with mu, otherwise the fixed lambda is used.
            epoch_lam_c = (args.lambda_c_muFact * mu
                           if args.lambda_c_muFact > 0 else args.lambda_c)
            epoch_lam_w = (args.lambda_w_muFact * mu
                           if args.lambda_w_muFact > 0 else args.lambda_w)
# Summary writers for the loss and accuracy logs.
writer_loss = SummaryWriter(gen_path(loss_path))
writer_acc = SummaryWriter(gen_path(acc_path))

# Both MNIST splits share one transform: tensor conversion followed by
# normalization with the constants below.
trans_mnist = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
dataset_train = MNIST('./data/mnist/',
                      train=True,
                      download=True,
                      transform=trans_mnist)
dataset_test = MNIST('./data/mnist/',
                     train=False,
                     download=True,
                     transform=trans_mnist)

# Assign training samples to the participating nodes.
dict_users = split_noniid_shuffle(dataset_train, args.num_nodes)

img_size = dataset_train[0][0].shape
print(img_size)

# Global network, placed on the configured device and set to training mode.
net_glob = LeNet().to(args.device)
print(net_glob.fc1.weight.type())
print(net_glob)
net_glob.train()

# "Copies" of the global weights: every name below refers to the very same
# state-dict object, none of them is an independent copy.
w_glob = net_glob.state_dict()
w_glob_grad = w_glob
w_locals = [w_glob] * args.num_nodes

# Training rounds: one LocalUpdate per node per round.
for iter in range(args.epochs):
    loss_locals = []
    for idx in range(args.num_nodes):
        local = LocalUpdate(args=args,
                            dataset=dataset_train,
                            idxs=dict_users[idx])
class Reptile(object):
    """Reptile meta-learning loop around a ``LeNet`` classifier.

    ``args`` must provide at least: ``device``, ``model_path``,
    ``max_num_classes``, ``num_classes``, ``outer_stepsize``,
    ``inner_stepsize``, ``inner_batch_size``, ``inner_iterations``,
    ``inner_iterations_test``, ``n_iterations`` and ``log_every``.
    """

    def __init__(self, args):
        self.args = args
        self._load_model()
        self.model.to(args.device)
        self.task_generator = TaskGen(args.max_num_classes)
        self.outer_stepsize = args.outer_stepsize
        self.criterion = nn.CrossEntropyLoss()

    def _load_model(self):
        """Create the model and, if a checkpoint exists, try to resume it."""
        self.model = LeNet()
        self.current_iteration = 0
        if not os.path.exists(self.args.model_path):
            return
        try:
            print("Loading model from: {}".format(self.args.model_path))
            # map_location lets a checkpoint written on a GPU machine load
            # on a CPU-only one; __init__ moves the model afterwards.
            self.model.load_state_dict(
                torch.load(self.args.model_path, map_location="cpu"))
            self.current_iteration = joblib.load("{}.iter".format(
                self.args.model_path))
        except Exception as e:
            # Best effort: a broken or partial checkpoint must not abort.
            print(
                "Exception: {}\nCould not load model from {} - starting from scratch"
                .format(e, self.args.model_path))

    def _save_checkpoint(self):
        """Write the model weights and the iteration counter to disk."""
        torch.save(self.model.state_dict(), self.args.model_path)
        joblib.dump(self.current_iteration,
                    "{}.iter".format(self.args.model_path),
                    compress=1)

    def inner_training(self, x, y, num_iterations):
        """Run ``num_iterations`` plain-SGD steps on one task.

        Args:
            x: task inputs (array-like, converted to a float tensor).
            y: task labels (array-like, converted to integer class ids).
            num_iterations: number of mini-batch updates to perform.

        Returns:
            Mean mini-batch loss over the performed iterations, detached
            from the autograd graph (``0.0`` when ``num_iterations`` is 0).

        Raises:
            ValueError: if the task contains no samples.
        """
        x, y = shuffle_unison(x, y)
        self.model.train()
        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        y = torch.tensor(y, dtype=torch.float, device=self.args.device).long()
        if len(x) == 0:
            raise ValueError("inner_training needs at least one sample")

        # Few-shot tasks can be smaller than the configured batch size.
        batch_size = min(self.args.inner_batch_size, len(x))
        total_loss = 0
        for _ in range(num_iterations):
            # Random contiguous window over the shuffled task data.
            start = np.random.randint(0, len(x) - batch_size + 1)
            self.model.zero_grad()
            outputs = self.model(x[start:start + batch_size])
            loss = self.criterion(outputs, y[start:start + batch_size])
            # Detach so the running sum does not keep every graph alive.
            total_loss += loss.detach()
            loss.backward()
            # Manual equivalent of optimizer.step() for vanilla SGD.
            with torch.no_grad():
                for param in self.model.parameters():
                    if param.grad is not None:
                        param -= self.args.inner_stepsize * param.grad
        return total_loss / max(num_iterations, 1)

    def _meta_gradient_update(self, iteration, num_classes, weights_before):
        """Interpolate between the pre-task weights and the trained weights.

        ``weights_after - weights_before`` acts as the meta-gradient; the
        step size decays linearly with ``iteration``.

        Args:
            iteration: current iteration, used for the step-size schedule.
            num_classes: number of classes of the current task (unused).
            weights_before: state dict captured before the inner training.
        """
        weights_after = self.model.state_dict()
        outer_stepsize = self.outer_stepsize * (
            1 - iteration / self.args.n_iterations)  # linear schedule
        self.model.load_state_dict({
            name: weights_before[name] +
            (weights_after[name] - weights_before[name]) * outer_stepsize
            for name in weights_before
        })

    def meta_training(self):
        """Reptile outer loop; checkpoints on completion and on Ctrl-C."""
        total_loss = 0
        # Losses are only accumulated for this session, so the running mean
        # must not count iterations restored from a checkpoint.
        start_iteration = self.current_iteration
        try:
            while self.current_iteration < self.args.n_iterations:
                # Generate task
                (data, labels, original_labels,
                 num_classes) = self.task_generator.get_train_task(
                     self.args.num_classes)

                weights_before = deepcopy(self.model.state_dict())
                loss = self.inner_training(data, labels,
                                           self.args.inner_iterations)
                total_loss += loss

                if self.current_iteration % self.args.log_every == 0:
                    done = self.current_iteration - start_iteration + 1
                    print("-----------------------------")
                    print("iteration {}".format(self.current_iteration + 1))
                    print("Loss: {:.3f}".format(total_loss / done))
                    print("Current task info: ")
                    print("\t- Number of classes: {}".format(num_classes))
                    print("\t- Batch size: {}".format(len(data)))
                    print("\t- Labels: {}".format(set(original_labels)))
                    self.test()

                self._meta_gradient_update(self.current_iteration,
                                           num_classes, weights_before)
                self.current_iteration += 1

            # Persist the counter too, so a finished run is seen as finished
            # the next time the checkpoint is loaded.
            self._save_checkpoint()
        except KeyboardInterrupt:
            print("Manual Interrupt...")
            print("Saving to: {}".format(self.args.model_path))
            self._save_checkpoint()

    def predict(self, x):
        """Return the model outputs for ``x`` as a NumPy array."""
        self.model.eval()
        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = self.model(x)
        return outputs.cpu().numpy()

    def test(self):
        """Report accuracy before and after 1- to 4-shot adaptation.

        1. Create a task from the test set.
        2. Check accuracy on it with the current weights.
        3. For each shot count: train on an enrollment task, check the
           accuracy again, then restore the weights from the snapshot.
        """
        test_data, test_labels, _, _ = self.task_generator.get_test_task(
            selected_labels=[1, 2, 3, 4, 5],
            num_samples=-1)  # all available samples
        predicted_labels = np.argmax(self.predict(test_data), axis=1)
        accuracy = np.mean(1 * (predicted_labels == test_labels)) * 100
        print(
            "Accuracy before few shots learning (a.k.a. zero-shot learning): {:.2f}%\n----"
            .format(accuracy))

        # Snapshot so the evaluation never leaks into the meta-training.
        weights_before = deepcopy(self.model.state_dict())
        for i in range(1, 5):
            enroll_data, enroll_labels, _, _ = self.task_generator.get_enroll_task(
                selected_labels=[1, 2, 3, 4, 5], num_samples=i)
            try:
                self.inner_training(enroll_data, enroll_labels,
                                    self.args.inner_iterations_test)
                predicted_labels = np.argmax(self.predict(test_data), axis=1)
                accuracy = np.mean(1 * (predicted_labels == test_labels)) * 100
                print("Accuracy after {} shot{} learning: {:.2f}%)".format(
                    i, "" if i == 1 else "s", accuracy))
            finally:
                # Restore even on interrupt, so adapted weights are never
                # the ones that end up in the checkpoint.
                self.model.load_state_dict(weights_before)