Beispiel #1
0
def main():
    """Train one single-output LeNet per class and report AUROC per task.

    For every entry of ``ALL_CLASSES`` a fresh ``LeNet(num_classes=1)`` is
    built and trained for ``EPOCHS`` epochs with ``nn.BCELoss`` and SGD. The
    learning rate is multiplied by ``LR_DROP`` every ``EPOCHS_DROP`` epochs.
    After each epoch the test loss is computed by calling ``train`` with
    ``test=True`` and a checkpoint is written through ``save_checkpoint``.
    When the epochs are done, ``best.pt`` is loaded from ``CHECKPOINT`` and
    ``calc_avg_AUROC`` is evaluated on the test loader.

    The running average AUROC after each task is collected and printed at
    the end.

    Reads module-level configuration (``CHECKPOINT``, ``BATCH_SIZE``,
    ``NUM_WORKERS``, ``ALL_CLASSES``, ``CUDA``, ``LEARNING_RATE``,
    ``MOMENTUM``, ``WEIGHT_DECAY``, ``EPOCHS``, ``EPOCHS_DROP``,
    ``LR_DROP``) and helpers defined outside this function.
    """
    # exist_ok=True replaces a separate isdir() check, which could race with
    # another process creating the directory in between.
    os.makedirs(CHECKPOINT, exist_ok=True)

    print('==> Preparing dataset')

    trainloader, testloader = load_CIFAR(batch_size=BATCH_SIZE,
                                         num_workers=NUM_WORKERS)

    # AUROCs[i] holds the running average AUROC after task i + 1.
    AUROCs = []
    auroc = AverageMeter()

    for t, cls in enumerate(ALL_CLASSES):

        print('\nTask: [%d | %d]\n' % (t + 1, len(ALL_CLASSES)))

        # Each task works on exactly one class; the same list is passed for
        # both class arguments of train() and calc_avg_AUROC().
        CLASSES = [cls]

        print("==> Creating model")
        model = LeNet(num_classes=1)

        if CUDA:
            model = model.cuda()
            model = nn.DataParallel(model)
            cudnn.benchmark = True

        print('    Total params: %.2fK' %
              (sum(p.numel() for p in model.parameters()) / 1000))

        criterion = nn.BCELoss()
        optimizer = optim.SGD(model.parameters(),
                              lr=LEARNING_RATE,
                              momentum=MOMENTUM,
                              weight_decay=WEIGHT_DECAY)

        print("==> Learning")

        # Reset per task so the first epoch of every task is always "best".
        best_loss = 1e10
        learning_rate = LEARNING_RATE

        for epoch in range(EPOCHS):

            # Step decay: shrink the learning rate every EPOCHS_DROP epochs.
            if (epoch + 1) % EPOCHS_DROP == 0:
                learning_rate *= LR_DROP
                for param_group in optimizer.param_groups:
                    param_group['lr'] = learning_rate

            print('Epoch: [%d | %d]' % (epoch + 1, EPOCHS))

            train_loss = train(trainloader,
                               model,
                               criterion,
                               CLASSES,
                               CLASSES,
                               optimizer=optimizer,
                               use_cuda=CUDA)
            # Same routine with test=True and no optimizer yields the loss
            # on the test loader.
            test_loss = train(testloader,
                              model,
                              criterion,
                              CLASSES,
                              CLASSES,
                              test=True,
                              use_cuda=CUDA)

            # Save a checkpoint every epoch and flag it when the test loss
            # improved on the best value seen so far in this task.
            is_best = test_loss < best_loss
            best_loss = min(test_loss, best_loss)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'loss': test_loss,
                    'optimizer': optimizer.state_dict()
                }, CHECKPOINT, is_best)

        print("==> Calculating AUROC")

        # Presumably save_checkpoint writes best.pt when is_best is true;
        # verify against its definition.
        filepath_best = os.path.join(CHECKPOINT, "best.pt")
        checkpoint = torch.load(filepath_best)
        model.load_state_dict(checkpoint['state_dict'])

        new_auroc = calc_avg_AUROC(model, testloader, CLASSES, CLASSES, CUDA)
        auroc.update(new_auroc)

        print('New Task AUROC: {}'.format(new_auroc))
        print('Average AUROC: {}'.format(auroc.avg))

        AUROCs.append(auroc.avg)

    print('\nAverage Per-task Performance over number of tasks')
    for i, p in enumerate(AUROCs):
        print("%d: %f" % (i + 1, p))
def main():
    """Run the full training pipeline on the module-level datasets.

    The steps, in order:

    1. Build train/val/test loaders and the classifier selected by
       ``args.dataset``; evaluate the untrained model once.
    2. Train with cross-entropy for epochs ``1 .. args.n_epoch_1 - 1`` and
       save the weights whenever validation accuracy improves.
    3. Reload those weights, extract representations, factorize them with
       ``train_m``, estimate transition matrices and optimize the basis
       matrices for the train and val splits.
    4. Train with ``train_correction`` for epochs ``1 .. args.n_epoch_2 - 1``,
       again saving on validation improvement.
    5. Reload the saved weights, zero ``clf1.T_revision.weight`` and train
       with ``train_revision`` for epochs ``1 .. args.n_epoch_3 - 1``.

    One line per epoch (``epoch train_acc val_acc test_acc``) is appended to
    ``txtfile``.

    Reads module-level names defined outside this function: ``args``,
    ``train_dataset``, ``val_dataset``, ``test_dataset``, ``batch_size``,
    ``txtfile``, ``model_save_dir`` and the training/evaluation helpers.

    Raises:
        ValueError: if ``args.dataset`` is not one of the supported names.
    """
    n_train = len(train_dataset)
    n_val = len(val_dataset)
    n_test = len(test_dataset)
    model_path = model_save_dir + '/' + 'model.pth'

    def _make_loader(dataset):
        # All three loaders share the same settings.
        # NOTE(review): shuffle is off for the training loader as well;
        # presumably later stages match rows of W to samples by position --
        # confirm before enabling shuffling.
        return torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=batch_size,
                                           num_workers=args.num_workers,
                                           drop_last=False,
                                           shuffle=False)

    def _append_result(epoch, train_acc, val_acc, test_acc):
        # Append one space-separated result row (trailing space kept).
        row = ' '.join([str(int(epoch)), str(train_acc), str(val_acc),
                        str(test_acc)])
        with open(txtfile, "a") as myfile:
            myfile.write(row + ' ' + "\n")

    def _print_accuracies(epoch, n_epochs, train_acc, val_acc, test_acc):
        # Print train/val/test accuracy for one epoch.
        print(
            'Epoch [%d/%d] Train Accuracy on the %s train data: Model %.4f %%'
            % (epoch + 1, n_epochs, n_train, train_acc))
        print('Epoch [%d/%d] Val Accuracy on the %s val data: Model %.4f %% ' %
              (epoch + 1, n_epochs, n_val, val_acc))
        print(
            'Epoch [%d/%d] Test Accuracy on the %s test data: Model %.4f %% ' %
            (epoch + 1, n_epochs, n_test, test_acc))

    # Data Loader (Input Pipeline)
    print('loading dataset...')
    train_loader = _make_loader(train_dataset)
    val_loader = _make_loader(val_dataset)
    test_loader = _make_loader(test_dataset)

    # Define models
    print('building model...')
    if args.dataset == 'mnist':
        clf1 = LeNet()
    elif args.dataset == 'fashionmnist':
        clf1 = resnet.ResNet18_F(10)
    elif args.dataset in ('cifar10', 'svhn'):
        clf1 = resnet.ResNet34(10)
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError('unsupported dataset: %r' % (args.dataset,))

    clf1.cuda()
    optimizer = torch.optim.SGD(clf1.parameters(),
                                lr=args.lr,
                                weight_decay=args.weight_decay)

    with open(txtfile, "a") as myfile:
        myfile.write('epoch train_acc val_acc test_acc\n')

    epoch = 0
    train_acc = 0
    val_acc = 0
    # evaluate models with random weights
    test_acc = evaluate(test_loader, clf1)
    print('Epoch [%d/%d] Test Accuracy on the %s test data: Model1 %.4f %%' %
          (epoch + 1, args.n_epoch_1, n_test, test_acc))
    # save results
    _append_result(epoch, train_acc, val_acc, test_acc)

    # ---- cross-entropy training ----
    criterion = nn.CrossEntropyLoss()
    best_acc = 0.0
    for epoch in range(1, args.n_epoch_1):
        clf1.train()
        train_acc = train(clf1, train_loader, epoch, optimizer, criterion)
        val_acc = evaluate(val_loader, clf1)
        test_acc = evaluate(test_loader, clf1)

        _print_accuracies(epoch, args.n_epoch_1, train_acc, val_acc, test_acc)
        _append_result(epoch, train_acc, val_acc, test_acc)

        # Keep only the weights with the best validation accuracy so far.
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), model_path)

    # ---- matrix factorization and transition-matrix estimation ----
    print('Matrix Factorization is doing...')
    clf1.load_state_dict(torch.load(model_path))
    A = respresentations_extract(train_loader, clf1, n_train, args.dim,
                                 batch_size)
    A_val = respresentations_extract(val_loader, clf1, n_val, args.dim,
                                     batch_size)
    # Factorize train and val representations together, then split the
    # coefficient rows back into the two parts.
    A_total = np.append(A, A_val, axis=0)
    W_total, H_total, error = train_m(A_total, args.basis, args.iteration_nmf,
                                      1e-5)
    # Zero out near-zero coefficients in one masked assignment.
    W_total[W_total < 1e-6] = 0.
    W = W_total[0:n_train, :]
    W_val = W_total[n_train:, :]

    print('Transition Matrix is estimating...Wating...')
    logits_matrix = probability_extract(train_loader, clf1, n_train,
                                        args.num_classes, batch_size)
    idx_matrix_group, transition_matrix_group = estimate_matrix(
        logits_matrix, model_save_dir)
    logits_matrix_val = probability_extract(val_loader, clf1, n_val,
                                            args.num_classes, batch_size)
    idx_matrix_group_val, transition_matrix_group_val = estimate_matrix(
        logits_matrix_val, model_save_dir)
    func = nn.MSELoss()

    # The same model and optimizer instances are used for both the train
    # and the val optimization calls.
    model = Matrix_optimize(args.basis, args.num_classes)
    optimizer_1 = torch.optim.Adam(model.parameters(), lr=0.001)
    basis_matrix_group = basis_matrix_optimize(model, optimizer_1, args.basis,
                                               args.num_classes, W,
                                               transition_matrix_group,
                                               idx_matrix_group, func,
                                               model_save_dir, args.n_epoch_4)

    basis_matrix_group_val = basis_matrix_optimize(
        model, optimizer_1, args.basis, args.num_classes, W_val,
        transition_matrix_group_val, idx_matrix_group_val, func,
        model_save_dir, args.n_epoch_4)

    # Only the training basis matrices are thresholded.
    basis_matrix_group[basis_matrix_group < 1e-6] = 0.

    # ---- training with train_correction ----
    optimizer_ = torch.optim.SGD(clf1.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay,
                                 momentum=args.momentum)

    best_acc = 0.0
    for epoch in range(1, args.n_epoch_2):
        clf1.train()

        train_acc = train_correction(clf1, train_loader, epoch, optimizer_, W,
                                     basis_matrix_group, batch_size,
                                     args.num_classes, args.basis)
        val_acc = val_correction(clf1, val_loader, epoch, W_val,
                                 basis_matrix_group_val, batch_size,
                                 args.num_classes, args.basis)
        test_acc = evaluate(test_loader, clf1)

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(clf1.state_dict(), model_path)

        _append_result(epoch, train_acc, val_acc, test_acc)
        _print_accuracies(epoch, args.n_epoch_2, train_acc, val_acc, test_acc)

    # ---- training with train_revision ----
    clf1.load_state_dict(torch.load(model_path))
    optimizer_r = torch.optim.Adam(clf1.parameters(),
                                   lr=args.lr_revision,
                                   weight_decay=args.weight_decay)
    nn.init.constant_(clf1.T_revision.weight, 0.0)

    for epoch in range(1, args.n_epoch_3):
        clf1.train()
        train_acc = train_revision(clf1, train_loader, epoch, optimizer_r, W,
                                   basis_matrix_group, batch_size,
                                   args.num_classes, args.basis)
        # NOTE(review): this passes the training basis matrices together
        # with W_val, whereas val_correction above receives
        # basis_matrix_group_val -- confirm which one is intended.
        val_acc = val_revision(clf1, val_loader, epoch, W_val,
                               basis_matrix_group, batch_size,
                               args.num_classes, args.basis)
        test_acc = evaluate(test_loader, clf1)

        _append_result(epoch, train_acc, val_acc, test_acc)
        _print_accuracies(epoch, args.n_epoch_3, train_acc, val_acc, test_acc)