Example #1
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms

# Local project modules (assumed): model definitions (LeNet, AlexNet, VGG*,
# GoogLeNet, resnet*, DenseNet*, WideResNet) and a console progress_bar helper.


class Solver(object):
    def __init__(self, config):
        self.model = None
        self.lr = config.lr
        self.epochs = config.epoch
        self.train_batch_size = config.trainBatchSize
        self.test_batch_size = config.testBatchSize
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        train_transform = transforms.Compose(
            [transforms.RandomHorizontalFlip(),
             transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        train_set = torchvision.datasets.CIFAR10(root='./data',
                                                 train=True,
                                                 download=True,
                                                 transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=self.train_batch_size, shuffle=True)
        test_set = torchvision.datasets.CIFAR10(root='./data',
                                                train=False,
                                                download=True,
                                                transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

    def load_model(self):
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        self.model = LeNet().to(self.device)
        # self.model = AlexNet().to(self.device)
        # self.model = VGG11().to(self.device)
        # self.model = VGG13().to(self.device)
        # self.model = VGG16().to(self.device)
        # self.model = VGG19().to(self.device)
        # self.model = GoogLeNet().to(self.device)
        # self.model = resnet18().to(self.device)
        # self.model = resnet34().to(self.device)
        # self.model = resnet50().to(self.device)
        # self.model = resnet101().to(self.device)
        # self.model = resnet152().to(self.device)
        # self.model = DenseNet121().to(self.device)
        # self.model = DenseNet161().to(self.device)
        # self.model = DenseNet169().to(self.device)
        # self.model = DenseNet201().to(self.device)
        # self.model = WideResNet(depth=28, num_classes=10).to(self.device)

        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=[75, 150],
                                                        gamma=0.5)
        self.criterion = nn.CrossEntropyLoss().to(self.device)

    def train(self):
        print("train:")
        self.model.train()
        train_loss = 0
        train_correct = 0
        total = 0

        for batch_num, (data, target) in enumerate(self.train_loader):
            data, target = data.to(self.device), target.to(self.device)
            self.optimizer.zero_grad()
            output = self.model(data)
            # the model caches its last feature map (used for CAM-style debugging)
            feature = self.model.feature
            # print('output.shape = {}, target.shape = {}, feature.shape = {}'.format(output.size(), target.size(), feature.size()))
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            # torch.max returns (values, indices); the indices along dim=1
            # are the predicted class labels
            _, predicted = torch.max(output, 1)
            total += target.size(0)

            # count correct predictions in this batch
            train_correct += np.sum(
                predicted.cpu().numpy() == target.cpu().numpy())

            progress_bar(
                batch_num, len(self.train_loader),
                'Loss: %.4f | Acc: %.3f%% (%d/%d)' %
                (train_loss / (batch_num + 1), 100. * train_correct / total,
                 train_correct, total))

        return train_loss, train_correct / total

    def test(self):
        print("test:")
        self.model.eval()
        test_loss = 0
        test_correct = 0
        total = 0

        with torch.no_grad():
            for batch_num, (data, target) in enumerate(self.test_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)

                # CAM
                # feature = self.model.feature
                # print('feature: {}'.format(feature))

                loss = self.criterion(output, target)
                test_loss += loss.item()
                _, predicted = torch.max(output, 1)
                total += target.size(0)
                test_correct += np.sum(
                    predicted.cpu().numpy() == target.cpu().numpy())

                progress_bar(
                    batch_num, len(self.test_loader),
                    'Loss: %.4f | Acc: %.3f%% (%d/%d)' %
                    (test_loss / (batch_num + 1), 100. * test_correct / total,
                     test_correct, total))

        return test_loss, test_correct / total

    def save(self):
        model_out_path = "model.pth"
        torch.save(self.model, model_out_path)
        print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        self.load_data()
        print('Success loading data.')
        self.load_model()
        print('Success loading model.')
        accuracy = 0
        for epoch in range(1, self.epochs + 1):
            print("\n===> epoch: %d/%d" % (epoch, self.epochs))
            train_result = self.train()
            print(train_result)
            # step the LR schedule once per epoch (passing an epoch argument
            # to scheduler.step() is deprecated)
            self.scheduler.step()
            test_result = self.test()
            accuracy = max(accuracy, test_result[1])
            if epoch == self.epochs:
                print("===> BEST ACC. PERFORMANCE: %.3f%%" % (accuracy * 100))
                self.save()
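
Usage sketch for Example #1 (not from the original source; the config
attribute names mirror what Solver.__init__ reads, values are illustrative):

import argparse

config = argparse.Namespace(lr=1e-3, epoch=200, trainBatchSize=100,
                            testBatchSize=100,
                            cuda=torch.cuda.is_available())
solver = Solver(config)
solver.run()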
Example #2

# Assumes numpy as np, torch, and torch.nn as nn are imported, along with
# project-local helpers used below (LeNet, resnet, train, evaluate,
# respresentations_extract, probability_extract, train_m, estimate_matrix,
# Matrix_optimize, basis_matrix_optimize, train_correction, val_correction,
# train_revision, val_revision) and globals such as args, txtfile,
# model_save_dir, batch_size and the train/val/test datasets.
def main():
    # Data Loader (Input Pipeline)
    print('loading dataset...')
    # shuffle=False keeps the sample order fixed, so the per-sample matrices
    # extracted later (W, logits_matrix) stay aligned with dataset indices
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=batch_size,
                                               num_workers=args.num_workers,
                                               drop_last=False,
                                               shuffle=False)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                             batch_size=batch_size,
                                             num_workers=args.num_workers,
                                             drop_last=False,
                                             shuffle=False)

    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              num_workers=args.num_workers,
                                              drop_last=False,
                                              shuffle=False)
    # Define models
    print('building model...')
    if args.dataset == 'mnist':
        clf1 = LeNet()
    elif args.dataset == 'fashionmnist':
        clf1 = resnet.ResNet18_F(10)
    elif args.dataset == 'cifar10':
        clf1 = resnet.ResNet34(10)
    elif args.dataset == 'svhn':
        clf1 = resnet.ResNet34(10)

    clf1.cuda()
    optimizer = torch.optim.SGD(clf1.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)

    with open(txtfile, "a") as myfile:
        myfile.write('epoch train_acc val_acc test_acc\n')

    epoch = 0
    train_acc = 0
    val_acc = 0
    # evaluate models with random weights
    test_acc = evaluate(test_loader, clf1)
    print('Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %%' %
          (epoch + 1, args.n_epoch_1, len(test_dataset), test_acc))
    # save results
    with open(txtfile, "a") as myfile:
        myfile.write('{} {} {} {}\n'.format(int(epoch), train_acc, val_acc,
                                            test_acc))

    best_acc = 0.0
    # training
    for epoch in range(1, args.n_epoch_1):
        # train models
        clf1.train()
        train_acc = train(clf1, train_loader, epoch, optimizer,
                          nn.CrossEntropyLoss())
        # validation
        val_acc = evaluate(val_loader, clf1)
        # evaluate models
        test_acc = evaluate(test_loader, clf1)

        # save results
        print(
            'Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%'
            % (epoch + 1, args.n_epoch_1, len(train_dataset), train_acc))
        print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' %
              (epoch + 1, args.n_epoch_1, len(val_dataset), val_acc))
        print(
            'Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' %
            (epoch + 1, args.n_epoch_1, len(test_dataset), test_acc))
        with open(txtfile, "a") as myfile:
            myfile.write('{} {} {} {}\n'.format(int(epoch), train_acc,
                                                val_acc, test_acc))

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), model_save_dir + '/' + 'model.pth')

    print('Running matrix factorization...')
    clf1.load_state_dict(torch.load(model_save_dir + '/' + 'model.pth'))
    A = respresentations_extract(train_loader, clf1, len(train_dataset),
                                 args.dim, batch_size)
    A_val = respresentations_extract(val_loader, clf1, len(val_dataset),
                                     args.dim, batch_size)
    A_total = np.append(A, A_val, axis=0)
    W_total, H_total, error = train_m(A_total, args.basis, args.iteration_nmf,
                                      1e-5)
    # zero out negligibly small entries
    W_total[W_total < 1e-6] = 0.
    W = W_total[0:len(train_dataset), :]
    W_val = W_total[len(train_dataset):, :]
    print('Estimating transition matrix... please wait...')
    logits_matrix = probability_extract(train_loader, clf1, len(train_dataset),
                                        args.num_classes, batch_size)
    idx_matrix_group, transition_matrix_group = estimate_matrix(
        logits_matrix, model_save_dir)
    logits_matrix_val = probability_extract(val_loader, clf1, len(val_dataset),
                                            args.num_classes, batch_size)
    idx_matrix_group_val, transition_matrix_group_val = estimate_matrix(
        logits_matrix_val, model_save_dir)
    func = nn.MSELoss()

    model = Matrix_optimize(args.basis, args.num_classes)
    optimizer_1 = torch.optim.Adam(model.parameters(), lr=0.001)
    basis_matrix_group = basis_matrix_optimize(model, optimizer_1, args.basis,
                                               args.num_classes, W,
                                               transition_matrix_group,
                                               idx_matrix_group, func,
                                               model_save_dir, args.n_epoch_4)

    basis_matrix_group_val = basis_matrix_optimize(
        model, optimizer_1, args.basis, args.num_classes, W_val,
        transition_matrix_group_val, idx_matrix_group_val, func,
        model_save_dir, args.n_epoch_4)

    # zero out negligibly small entries
    basis_matrix_group[basis_matrix_group < 1e-6] = 0.

    optimizer_ = torch.optim.SGD(clf1.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 momentum=args.momentum)

    best_acc = 0.0
    for epoch in range(1, args.n_epoch_2):
        # train model
        clf1.train()

        train_acc = train_correction(clf1, train_loader, epoch, optimizer_, W,
                                     basis_matrix_group, batch_size,
                                     args.num_classes, args.basis)
        # validation
        val_acc = val_correction(clf1, val_loader, epoch, W_val,
                                 basis_matrix_group_val, batch_size,
                                 args.num_classes, args.basis)

        # evaluate models
        test_acc = evaluate(test_loader, clf1)
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), model_save_dir + '/' + 'model.pth')
        with open(txtfile, "a") as myfile:
            myfile.write('{} {} {} {}\n'.format(int(epoch), train_acc,
                                                val_acc, test_acc))
        # save results
        print(
            'Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%'
            % (epoch + 1, args.n_epoch_2, len(train_dataset), train_acc))
        print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' %
              (epoch + 1, args.n_epoch_2, len(val_dataset), val_acc))
        print(
            'Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' %
            (epoch + 1, args.n_epoch_2, len(test_dataset), test_acc))

    clf1.load_state_dict(torch.load(model_save_dir + '/' + 'model.pth'))
    optimizer_r = torch.optim.Adam(clf1.parameters(),
                                   lr=args.lr_revision,
                                   weight_decay=args.weight_decay)
    nn.init.constant_(clf1.T_revision.weight, 0.0)

    for epoch in range(1, args.n_epoch_3):
        # train models
        clf1.train()
        train_acc = train_revision(clf1, train_loader, epoch, optimizer_r, W,
                                   basis_matrix_group, batch_size,
                                   args.num_classes, args.basis)
        # validation
        val_acc = val_revision(clf1, val_loader, epoch, W_val,
                               basis_matrix_group, batch_size,
                               args.num_classes, args.basis)
        # evaluate models
        test_acc = evaluate(test_loader, clf1)
        with open(txtfile, "a") as myfile:
            myfile.write('{} {} {} {}\n'.format(int(epoch), train_acc,
                                                val_acc, test_acc))

        # save results
        print(
            'Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%'
            % (epoch + 1, args.n_epoch_3, len(train_dataset), train_acc))
        print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' %
              (epoch + 1, args.n_epoch_3, len(val_dataset), val_acc))
        print(
            'Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' %
            (epoch + 1, args.n_epoch_3, len(test_dataset), test_acc))
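
# Hypothetical entry point (not shown in the original snippet):
if __name__ == '__main__':
    main()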
Example #3
    mu = args.min_mu
    mu_max = args.max_mu

    for epoch in range(1, args.epochs + 1):

        # Configure dropout parameters.
        if args.LW_dropout_perc > 0 and args.LW_dropout_delay < epoch:
            useDropout = args.LW_dropout_perc
        else:
            useDropout = 0

        batch_size = args.batch_size + args.delta_batch_size * (epoch - 1)
        print('\nEpoch {} of {}. mu = {:.2f}, batch_size = {}, algorithm = {}'.
              format(epoch, args.epochs, mu, batch_size, algName))

        model.train()
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)

            if algName == 'altmin':
                #----------------------------------------------------------
                # Set L1 weights according to \mu.
                if args.lambda_c_muFact > 0:
                    epoch_lam_c = args.lambda_c_muFact * mu
                else:
                    epoch_lam_c = args.lambda_c
                if args.lambda_w_muFact > 0:
                    epoch_lam_w = args.lambda_w_muFact * mu
                else:
                    epoch_lam_w = args.lambda_w
Example #4

    writer_loss = SummaryWriter(gen_path(loss_path))
    writer_acc = SummaryWriter(gen_path(acc_path))

    trans_mnist = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.1307,),
                                                           (0.3081,))])
    dataset_train = MNIST('./data/mnist/', train=True, download=True,
                          transform=trans_mnist)
    dataset_test = MNIST('./data/mnist/', train=False, download=True,
                         transform=trans_mnist)
    # sample users
    dict_users = split_noniid_shuffle(dataset_train, args.num_nodes)

    img_size = dataset_train[0][0].shape
    print(img_size)

    net_glob = LeNet().to(args.device)
    print(net_glob.fc1.weight.type())
    print(net_glob)
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()
    w_glob_grad = w_glob

    # training
    #loss_train = []
    
    # every node starts from (a reference to) the same global weights
    w_locals = [w_glob for i in range(args.num_nodes)]

    for iter in range(args.epochs):
        loss_locals = []
        for idx in range(args.num_nodes):
            #import pdb; pdb.set_trace()
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
Example #5

import os
from copy import deepcopy

import joblib
import numpy as np
import torch
import torch.nn as nn

# Local helpers (assumed from the surrounding project): LeNet, TaskGen,
# and shuffle_unison.


class Reptile(object):
    def __init__(self, args):
        self.args = args
        self._load_model()

        self.model.to(args.device)
        self.task_generator = TaskGen(args.max_num_classes)
        self.outer_stepsize = args.outer_stepsize
        self.criterion = nn.CrossEntropyLoss()
        # self.optimizer = optim.Adam(self.model.parameters(), lr=args.inner_stepsize)

    def _load_model(self):
        self.model = LeNet()
        self.current_iteration = 0
        if os.path.exists(self.args.model_path):
            try:
                print("Loading model from: {}".format(self.args.model_path))
                self.model.load_state_dict(torch.load(self.args.model_path))
                self.current_iteration = joblib.load("{}.iter".format(
                    self.args.model_path))
            except Exception as e:
                print(
                    "Exception: {}\nCould not load model from {} - starting from scratch"
                    .format(e, self.args.model_path))

    def inner_training(self, x, y, num_iterations):
        """
        Run training on task
        """
        x, y = shuffle_unison(x, y)

        self.model.train()

        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        y = torch.tensor(y, dtype=torch.float, device=self.args.device)

        total_loss = 0
        for _ in range(num_iterations):
            start = np.random.randint(0,
                                      len(x) - self.args.inner_batch_size + 1)

            self.model.zero_grad()
            # self.optimizer.zero_grad()
            outputs = self.model(x[start:start + self.args.inner_batch_size])
            # print("output: {} - y: {}".format(outputs.shape, y.shape))
            # Variable is deprecated; plain tensors carry autograd info
            loss = self.criterion(
                outputs, y[start:start + self.args.inner_batch_size].long())
            total_loss += loss.item()  # .item() avoids holding the graph
            loss.backward()
            # self.optimizer.step()
            # Similar to calling optimizer.step()
            for param in self.model.parameters():
                param.data -= self.args.inner_stepsize * param.grad.data
        return total_loss / num_iterations

    def _meta_gradient_update(self, iteration, num_classes, weights_before):
        """
        Interpolate between current weights and trained weights from this task
        I.e. (weights_before - weights_after) is the meta-gradient

            - iteration: current iteration - used for updating outer_stepsize
            - num_classes: current classifier number of classes
            - weights_before: state of weights before inner steps training
        """
        weights_after = self.model.state_dict()
        outer_stepsize = self.outer_stepsize * (
            1 - iteration / self.args.n_iterations)  # linear schedule

        self.model.load_state_dict({
            name: weights_before[name] +
            (weights_after[name] - weights_before[name]) * outer_stepsize
            for name in weights_before
        })

    def meta_training(self):
        # Reptile training loop
        total_loss = 0
        try:
            while self.current_iteration < self.args.n_iterations:
                # Generate task
                data, labels, original_labels, num_classes = self.task_generator.get_train_task(
                    self.args.num_classes)

                weights_before = deepcopy(self.model.state_dict())
                loss = self.inner_training(data, labels,
                                           self.args.inner_iterations)
                total_loss += loss
                if self.current_iteration % self.args.log_every == 0:
                    print("-----------------------------")
                    print("iteration               {}".format(
                        self.current_iteration + 1))
                    print("Loss: {:.3f}".format(total_loss /
                                                (self.current_iteration + 1)))
                    print("Current task info: ")
                    print("\t- Number of classes: {}".format(num_classes))
                    print("\t- Batch size: {}".format(len(data)))
                    print("\t- Labels: {}".format(set(original_labels)))

                    self.test()

                self._meta_gradient_update(self.current_iteration, num_classes,
                                           weights_before)

                self.current_iteration += 1

            torch.save(self.model.state_dict(), self.args.model_path)

        except KeyboardInterrupt:
            print("Manual Interrupt...")
            print("Saving to: {}".format(self.args.model_path))
            torch.save(self.model.state_dict(), self.args.model_path)
            joblib.dump(self.current_iteration,
                        "{}.iter".format(self.args.model_path),
                        compress=1)

    def predict(self, x):
        self.model.eval()
        x = torch.tensor(x, dtype=torch.float, device=self.args.device)
        outputs = self.model(x)
        return outputs.detach().cpu().numpy()

    def test(self):
        """
        Run tests
            1. Create task from test set.
            2. Reload model
            3. Check accuracy on test set
            4. Train for one or more iterations on one task
            5. Check accuracy again on test set
        """

        test_data, test_labels, _, _ = self.task_generator.get_test_task(
            selected_labels=[1, 2, 3, 4, 5],
            num_samples=-1)  # all available samples
        predicted_labels = np.argmax(self.predict(test_data), axis=1)
        accuracy = np.mean(1 * (predicted_labels == test_labels)) * 100
        print(
            "Accuracy before few-shot learning (i.e. zero-shot): {:.2f}%\n----"
            .format(accuracy))

        weights_before = deepcopy(
            self.model.state_dict())  # save snapshot before evaluation
        for i in range(1, 5):
            enroll_data, enroll_labels, _, _ = self.task_generator.get_enroll_task(
                selected_labels=[1, 2, 3, 4, 5], num_samples=i)
            self.inner_training(enroll_data, enroll_labels,
                                self.args.inner_iterations_test)
            predicted_labels = np.argmax(self.predict(test_data), axis=1)
            accuracy = np.mean(1 * (predicted_labels == test_labels)) * 100

            print("Accuracy after {} shot{} learning: {:.2f}%)".format(
                i, "" if i == 1 else "s", accuracy))

        self.model.load_state_dict(weights_before)  # restore from snapshot
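
Usage sketch for this Reptile trainer (not from the original source; the
attribute names mirror what the class reads from `args`, values are
illustrative):

import argparse

args = argparse.Namespace(
    device='cuda' if torch.cuda.is_available() else 'cpu',
    model_path='reptile_model.pth',
    max_num_classes=5, num_classes=5,
    outer_stepsize=0.1, inner_stepsize=0.02,
    inner_batch_size=10, inner_iterations=5, inner_iterations_test=50,
    n_iterations=10000, log_every=100)
Reptile(args).meta_training()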