def run_alexnet_ann_recall_test_simulation_trial3():
    # instantiate AlexNet from the MNIST-trained checkpoint
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc1) # capture fc1 activations
    return run_alexnet_ann_recall_simulation(alex_cnn=alex_cnn, alex_capture=alex_capture, output_name="alexnet_recall_task_trial3.txt", num_nodes=1024)
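Intermediate_Capture is a project helper that none of these listings define; a minimal sketch, assuming it records a module's output through a PyTorch forward hook (the repo's actual implementation may differ):

import torch

class Intermediate_Capture:
    """Keep the most recent output of a wrapped module via a forward hook."""
    def __init__(self, module):
        self.output = None
        self._hook = module.register_forward_hook(self._capture)

    def _capture(self, module, inputs, output):
        # detach so captured activations do not keep the autograd graph alive
        self.output = output.detach()

    def remove(self):
        self._hook.remove()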
Example #2
    # training loop (the original snippet begins mid-function; the header is
    # reconstructed here to mirror the evaluation half below)
    model.train()
    train_loss = 0
    train_accu = 0
    for batch_x, batch_y in train_loader:
        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
        out = model(batch_x)
        loss = loss_func(out, batch_y)
        train_loss += loss.item()
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum()
        train_accu += train_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    mean_loss = train_loss / (len(train_data))
    mean_accu = train_accu / (len(train_data))
    #print(mean_loss,mean_accu)
    print('Training Loss: %.6f, Accu: %.6f' % (mean_loss, mean_accu))

    #evaluation------------------------
    model.eval()
    eval_loss = 0
    eval_accu = 0
    with torch.no_grad():  # volatile=True Variables are deprecated
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
            out = model(batch_x)
            loss = loss_func(out, batch_y)
            eval_loss += loss.item()
            pred = torch.max(out, 1)[1]
            num_correct = (pred == batch_y).sum()
            eval_accu += num_correct.item()
    mean_loss = eval_loss / (len(test_data))
    mean_accu = eval_accu / (len(test_data))
    print('Testing Loss: %.6f, Accu: %.6f' % (mean_loss, mean_accu))
Example #3
def test_sample_images(path_to_model, path_to_images, save_path):
    num_classes = 27

    # Load pre-trained AlexNet
    state_dict = torch.load(path_to_model,
                            map_location=lambda storage, loc: storage)['model']
    model = AlexNet(num_classes)
    model.load_state_dict(state_dict)
    model.eval()

    # Process every image
    dictionary = set(nltk.corpus.words.words())
    distances = defaultdict(lambda: defaultdict(lambda: 0))
    size_distances = defaultdict(lambda: defaultdict(lambda: 0))
    corrected_words = defaultdict(lambda: defaultdict(lambda: 0))
    with open('{}labels.txt'.format(path_to_images)) as f:
        for line in f:
            sections = line.split('; ')
            if len(sections) < 2:
                continue
            fname = sections[0]
            correct_word = sections[1]

            # Open image
            image = cv2.imread('{}{}'.format(path_to_images, fname))
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            _, image = cv2.threshold(image, 90, 255, cv2.THRESH_BINARY_INV)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)

            # Find 5 most probable results
            subimages = extract_characters(image, bounding_boxes)
            results = classify_characters(model, subimages)
            results = results[:5]

            # Check if word can be corrected
            corrected_word = ''
            for word in results:
                if word[0].lower() in dictionary and not corrected_word:
                    corrected_word = word[0]

            # Append to evaluation dicts for evaluation
            most_probable_word = results[0][0]
            distance = Levenshtein.distance(most_probable_word, correct_word)
            distances[len(correct_word)][distance] += 1
            size_distances[len(correct_word)][len(most_probable_word)] += 1

            corrected_words[len(correct_word)][0] += 1
            if corrected_word == correct_word:
                corrected_words[len(correct_word)][1] += 1

            # Print information about current progress
            print(
                'Correct: {:12s}  Most probable: {:12s}  Corrected: {:12s}  Distance: {:1d}  Success: {}'
                .format(correct_word, most_probable_word, corrected_word,
                        distance, corrected_word == correct_word))

    #  Save results
    with open('{}/test_results_distance.txt'.format(save_path), 'w') as f:
        for size in sorted(distances):
            for distance in sorted(distances[size]):
                f.write('{};{};{}\n'.format(size, distance,
                                            distances[size][distance]))

    with open('{}/test_results_size.txt'.format(save_path), 'w') as f:
        for size in sorted(size_distances):
            for size_distance in sorted(size_distances[size]):
                f.write('{};{};{}\n'.format(
                    size, size_distance, size_distances[size][size_distance]))

    with open('{}/test_results_corrected.txt'.format(save_path), 'w') as f:
        for key in sorted(corrected_words):
            for count in sorted(corrected_words[key]):
                f.write('{};{};{}\n'.format(key, count,
                                            corrected_words[key][count]))
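A hedged usage sketch for test_sample_images; the paths below are placeholders, not taken from the original project (path_to_images must end with a separator, since the code concatenates it directly with 'labels.txt'):

test_sample_images(path_to_model='trained_models/letters_30.pth',
                   path_to_images='data/test_words/',
                   save_path='results')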
Example #4
def train(train_loader, eval_loader, opt):
    print('==> Start training...')

    summary_writer = SummaryWriter('./runs/' + str(int(time.time())))

    is_cuda = torch.cuda.is_available()
    model = AlexNet()
    if is_cuda:
        model = model.cuda()

    optimizer = optim.SGD(
        params=model.parameters(),
        lr=opt.base_lr,
        momentum=0.9,
    )
    criterion = nn.CrossEntropyLoss()

    best_eval_acc = -0.1
    losses = AverageMeter()
    accuracies = AverageMeter()
    global_step = 0
    for epoch in range(1, opt.epochs + 1):
        # train (reset the meters so eval stats from the previous epoch
        # do not leak into this epoch's train averages)
        model.train()
        losses.reset()
        accuracies.reset()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            global_step += 1
            if is_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.update(loss.item(), outputs.shape[0])
            summary_writer.add_scalar('train/loss', loss.item(), global_step)

            _, preds = torch.max(outputs, dim=1)
            acc = preds.eq(targets).sum().item() / len(targets)
            accuracies.update(acc)
            summary_writer.add_scalar('train/acc', acc, global_step)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'],
                                  global_step)
        print(
            '==> Epoch: %d; Average Train Loss: %.4f; Average Train Acc: %.4f'
            % (epoch, losses.avg, accuracies.avg))

        # eval
        model.eval()
        losses.reset()
        accuracies.reset()
        for batch_idx, (inputs, targets) in enumerate(eval_loader):
            if is_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.update(loss.item(), outputs.shape[0])

            _, preds = torch.max(outputs, dim=1)
            acc = preds.eq(targets).sum().item() / len(targets)
            accuracies.update(acc)

        summary_writer.add_scalar('eval/loss', losses.avg, global_step)
        summary_writer.add_scalar('eval/acc', accuracies.avg, global_step)
        if accuracies.avg > best_eval_acc:
            best_eval_acc = accuracies.avg
            torch.save(model, './weights/best.pt')
        print(
            '==> Epoch: %d; Average Eval Loss: %.4f; Average/Best Eval Acc: %.4f / %.4f'
            % (epoch, losses.avg, accuracies.avg, best_eval_acc))
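Example #4 relies on an AverageMeter helper that is not shown; a minimal sketch consistent with how it is used above (update(val, n=1), .avg, .reset()):

class AverageMeter:
    """Track the running average of a scalar value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)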
Example #5
class TestNetwork():
    def __init__(self, dataset, batch_size, epochs):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load model and use CUDA if available
        self.model = AlexNet(num_classes)
        if torch.cuda.is_available():
            self.model.cuda()

        # Load testing dataset
        kwargs = {
            'num_workers': 1,
            'pin_memory': True
        } if torch.cuda.is_available() else {}
        self.test_loader = torch.utils.data.DataLoader(EMNIST(
            './data',
            dataset,
            download=True,
            transform=transforms.Compose([
                transforms.Lambda(correct_rotation),
                transforms.Resize((224, 224)),
                transforms.Grayscale(3),
                transforms.ToTensor(),
            ]),
            train=False),
                                                       batch_size=batch_size,
                                                       shuffle=True,
                                                       **kwargs)

        # Optimizer and loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def test(self, epoch):
        """
        Test the model for one epoch with a pre-trained network
        :param epoch: Current epoch
        :return: None
        """
        # Load weights from trained model
        state_dict = torch.load(
            './trained_models/{}_{}.pth'.format(self.dataset, epoch),
            map_location=lambda storage, loc: storage)['model']
        self.model.load_state_dict(state_dict)
        self.model.eval()

        test_loss = 0
        test_correct = 0
        progress = None
        for batch_idx, (data, target) in enumerate(self.test_loader):
            # Move data and label to the GPU if available
            if torch.cuda.is_available():
                data, target = data.cuda(), target.cuda()

            # Forward pass; accumulate loss and correct-prediction counts
            with torch.no_grad():
                output = self.model(data)
                loss = self.loss_fn(output, target)
            test_loss += loss.item()
            pred = output.max(1, keepdim=True)[1]
            test_correct += pred.eq(target.view_as(pred)).sum().item()

            # Print information about current step
            current_progress = int(100 * (batch_idx + 1) * self.batch_size /
                                   len(self.test_loader.dataset))
            if current_progress != progress and current_progress % 5 == 0:
                progress = current_progress
                print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, (batch_idx + 1) * len(data),
                    len(self.test_loader.dataset), current_progress,
                    loss.item()))

        test_loss /= (len(self.test_loader.dataset) / self.batch_size)
        test_correct /= len(self.test_loader.dataset)
        test_correct *= 100

        # Print information about current epoch
        print(
            'Test Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
                epoch, test_correct, test_loss))

    def start(self):
        """
        Start testing the network
        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.test(epoch)
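A hedged usage sketch; the dataset name and hyperparameters are placeholders:

tester = TestNetwork(dataset='letters', batch_size=64, epochs=10)
tester.start()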
Example #6
import torch
from torch.nn.functional import softmax

from alexnet import AlexNet
from utils import cifar10_loader, device, cifar10_classes

torch.random.manual_seed(128)
batch_size = 1
testloader = cifar10_loader(train=False, batch_size=batch_size)

net = AlexNet()
net.load_state_dict(torch.load("model/model.h5"))
net.eval()

correct = 0
total = 0


def run():
    global correct, total
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            inputs, labels = images.to(device), labels.to(device)
            outputs = net(inputs)
            # Top-5 prediction: count a hit when the true label appears
            # among the five highest-scoring classes (batch_size is 1 here)
            _, predicted = torch.topk(outputs, 5)
            top5 = predicted[0].cpu().tolist()
            total += labels.size(0)
            if labels.item() in top5:
                correct += 1
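The listing breaks off before run() is ever invoked; a minimal hedged completion that reports the accumulated counters:

run()
print('CIFAR-10 top-5 accuracy: {:.2f}%'.format(100.0 * correct / total))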
Example #7
def run_alexnet_ann_recall_test_simulation_trial7():
    output_name="alexnet_recall_task_trial7.txt"
    num_nodes=10
    full_connection_mat = np.ones(shape=(num_nodes,num_nodes)) - np.eye(num_nodes)
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc3) # for now capture final output

    transform = transforms.ToTensor()
    data_raw = MNIST(
        root='./data/mnist',
        train=True,
        download=True,
        transform=transform)

    # creating a toy dataset for simple probing
    mnist_subset = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
    per_class_sizes = {0:10, 1:10, 2:10, 3:10, 4:10, 5:10, 6:10, 7:10, 8:10, 9:10}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1,1, 28,28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done=False
        if done:
            break


    # convert mnist_subset into flat lists usable as model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # given a list of desired labels, choose examples of each label from the MNIST subset to store
    stored_size_vs_performance = [] # list will store tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):

        # need to generate probe set each time
        # when desired label size is k:
        # probe set is 10 instances each of labels 0 to k-1
        desired_labels = list(range(desired_label_size+1))
        sub_probe_set = []
        sub_probe_labels = []
        for des in desired_labels:
            # add 10 instances of des
            for inst in mnist_subset[des]:
                sub_probe_set.append(inst)
                sub_probe_labels.append(des)
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate hopnet performance
        ann_model = hopnet(num_nodes) 
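        # capture_process_fn binarizes the captured activations to +/-1 around
        # their exp-mean, since Hopfield-style ANNs store bipolar patterns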
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("Hopfield:", num_succ, ":", num_fail)
        hopfield_perf = int(num_succ)

        # evaluate popularity ANN performance
        # hyperparams: set c = N-1, with randomly generated connectivity matrix
        ann_model = PopularityANN(N=num_nodes, c=num_nodes-1, connectivity_matrix=full_connection_mat)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("PopularityANN:", num_succ, ":", num_fail)
        popularity_perf = int(num_succ)

        # evaluate orthogonal hebbs ANN performance
        ann_model = OrthogonalHebbsANN(N=num_nodes)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, sub_probe_set, sub_probe_labels, verbose=False)
        print("OrthogonalHebbsANN:", num_succ, ":", num_fail)
        ortho_perf = int(num_succ)

        stored_size_vs_performance.append((hopfield_perf, popularity_perf, ortho_perf))

    # write performance to file
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write("{},{},{}\n".format(perf[0], perf[1], perf[2]))
    return stored_size_vs_performance
Example #8
def run_alexnet_ann_recall_test_simulation_trial4():
    num_nodes = 10
    alex_cnn1 = AlexNet()
    alex_cnn1.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn1.eval()
    alex_capture1 = Intermediate_Capture(alex_cnn1.layer3) # capture layer3 activations
    output_name = "alexnet_recall_task_trial4.txt"

    alex_cnn2 = AlexNet()
    alex_cnn2.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn2.eval()
    alex_capture2 = Intermediate_Capture(alex_cnn2.layer4) # capture layer4 activations

    alex_cnn3 = AlexNet()
    alex_cnn3.load_state_dict(torch.load("trained_models/alexnet.pt", map_location=torch.device("cpu")))
    alex_cnn3.eval()
    alex_capture3 = Intermediate_Capture(alex_cnn3.layer5) # capture layer5 activations

    transform = transforms.ToTensor()
    data_raw = MNIST(
        root='./data/mnist',
        train=True,
        download=True,
        transform=transform)

    # creating a toy dataset for simple probing
    mnist_subset = {0:[], 1:[], 2:[], 3:[], 4:[], 5:[], 6:[], 7:[], 8:[], 9:[]}
    per_class_sizes = {0:10, 1:10, 2:10, 3:10, 4:10, 5:10, 6:10, 7:10, 8:10, 9:10}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1,1, 28,28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done=False
        if done:
            break


    # convert mnist_subset into flat lists usable as model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # given a list of desired labels, choose examples of each label from the MNIST subset to store
    stored_size_vs_performance = [] # list will store tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):
        desired_labels = list(range(desired_label_size+1))
        full_stored_set, full_stored_labels = create_storage_set(desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # evaluate hopnet performance
        ann_model = hopnet(6272) 
        model = CNN_ANN(alex_cnn1, ann_model, alex_capture1, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer3_perf = int(num_succ)

        ann_model = hopnet(12544) 
        model = CNN_ANN(alex_cnn2, ann_model, alex_capture2, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer4_perf = int(num_succ)

        ann_model = hopnet(2304) 
        model = CNN_ANN(alex_cnn3, ann_model, alex_capture3, capture_process_fn=lambda x: np.sign(np.exp(x)-np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(model, full_stored_set, full_stored_labels, full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer5_perf = int(num_succ)

        stored_size_vs_performance.append((layer3_perf, layer4_perf, layer5_perf))

    # write performance to file
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write("{},{},{}\n".format(perf[0], perf[1], perf[2]))
    return stored_size_vs_performance
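create_storage_set is another repo helper these listings never define; a minimal sketch consistent with its call sites (the reshape/make_numpy behavior is an assumption):

def create_storage_set(desired_labels, mnist_subset, reshape=True, make_numpy=True):
    """Collect the stored examples and labels for each desired label."""
    stored_set, stored_labels = [], []
    for label in desired_labels:
        for image in mnist_subset[label]:
            if reshape:
                image = image.reshape(-1)
            if make_numpy:
                image = image.numpy()
            stored_set.append(image)
            stored_labels.append(label)
    return stored_set, stored_labels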
Example #9
class LiveShowcase:
    def __init__(self, path_to_model):
        num_classes = 27

        # Member variables
        self.status = 'Ready'
        self.last_words = None
        self.dictionary_set = set(nltk.corpus.words.words())

        # Load pre-trained AlexNet
        state_dict = torch.load(path_to_model, map_location=lambda storage, loc: storage)['model']
        self.model = AlexNet(num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def process_image(self, image, bounding_boxes):
        """
        Process image to find and classify characters and build 5 most probable words
        :param image: rgb image
        :param bounding_boxes: list of bounding boxes containing characters (min_x, min_y, width, height)
        :return: None
        """
        self.status = 'Processing'

        # Find 5 most probable words
        subimages = extract_characters(image, bounding_boxes)
        words = classify_characters(self.model, subimages)
        self.last_words = words[:5]

        self.status = 'Ready'

    def start(self, max_bounding_boxes=10):
        """
        Start the live showcase using a camera
        :return: None
        """
        # Try to open a connection to the camera
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print('Error: No camera found')
            return
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 960)

        print('Press q to stop the live showcase')
        while True:
            # Capture frame-by-frame
            ret, image = cap.read()
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)
            for box in bounding_boxes:
                cv2.rectangle(output, (box[0], box[1]), (box[0] + box[2], box[1] + box[3]), (0, 0, 255), 2)

            # Process image if no other image is processed
            if 'Ready' in self.status:
                if len(bounding_boxes) > max_bounding_boxes:
                    self.status = 'Ready [Warning: too many bounding boxes]'
                    self.last_words = None
                else:
                    thread = threading.Thread(target=self.process_image, args=(image, bounding_boxes), daemon=True)
                    thread.start()

            # Draw status bar with last recognized words
            cv2.putText(output, 'Status: {}'.format(self.status), (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                        (0, 0, 0), 1, cv2.LINE_AA)
            if self.last_words:
                for offset, word in enumerate(self.last_words):
                    color = (0, 0, 0)
                    # Use green color if word is in dictionary
                    if word[0].lower() in self.dictionary_set:
                        color = (0, 255, 0)
                    cv2.putText(output, '{} ({:5.2f}%)'.format(word[0], 100 * word[1]),
                                (10, 20 + (offset + 1) * 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                                color, 1, cv2.LINE_AA)

            # Draw bounding box around detected word
            if self.last_words and len(bounding_boxes) > 1:
                word = self.last_words[0]
                color = (0, 0, 255)
                # Use green color if word is in dictionary
                if word[0].lower() in self.dictionary_set:
                    color = (0, 255, 0)
                text = '{} ({:5.2f}%)'.format(word[0], 100 * word[1])
                padding = 10
                top_left = (np.min([b[0] for b in bounding_boxes]) - padding,
                            np.min([b[1] for b in bounding_boxes]) - padding)
                bottom_right = (np.max([b[0]+b[2] for b in bounding_boxes]) + padding,
                                np.max([b[1]+b[3] for b in bounding_boxes]) + padding)
                text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                cv2.rectangle(output, (top_left[0] - 1, top_left[1] - text_size[1] - 2 * padding),
                                          (top_left[0] + text_size[0] + 2 * padding, top_left[1]),
                              color, thickness=cv2.FILLED)
                cv2.putText(output, text, (top_left[0] + padding, top_left[1] - padding),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
                cv2.rectangle(output, top_left, bottom_right, color, 2)

            # Display the resulting frame
            cv2.imshow('Image internal', image)
            cv2.imshow('Showcase', output)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # When everything done, release the capture
        cap.release()
        cv2.destroyAllWindows()
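A hedged usage sketch; the model path is a placeholder:

showcase = LiveShowcase('trained_models/letters_30.pth')
showcase.start(max_bounding_boxes=10)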
Example #10
class Solver(object):
    def __init__(self, config):
        self.model = None
        self.name = config.name
        self.lr = config.lr
        self.momentum = config.momentum
        self.beta = config.beta
        self.max_alpha = config.max_alpha
        self.epochs = config.epochs
        self.patience = config.patience
        self.N = config.N
        self.batch_size = config.batch_size
        self.random_labels = config.random_labels
        self.use_bn = config.batchnorm
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        # ToTensor scales pixel values from [0,255] to [0,1]
        mean_var = (125.3 / 255, 123.0 / 255,
                    113.9 / 255), (63.0 / 255, 62.1 / 255, 66.7 / 255)
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(*mean_var, inplace=True)
        ])
        train_set = torchvision.datasets.CIFAR10(root='./data',
                                                 train=True,
                                                 download=DOWNLOAD,
                                                 transform=transform)
        test_set = torchvision.datasets.CIFAR10(root='./data',
                                                train=False,
                                                download=DOWNLOAD,
                                                transform=transform)

        if self.random_labels:
            np.random.shuffle(train_set.targets)
            np.random.shuffle(test_set.targets)

        assert self.N <= 50000
        if self.N < 50000:
            train_set.data = train_set.data[:self.N]
            # downsize the test set to improve speed for small N
            test_set.data = test_set.data[:self.N]

        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=self.batch_size,
            shuffle=False,
            drop_last=True)

    def load_model(self):
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        self.model = AlexNet(device=self.device,
                             B=self.batch_size,
                             max_alpha=self.max_alpha,
                             use_bn=self.use_bn).to(self.device)

        self.optimizer = optim.SGD(self.model.parameters(),
                                   lr=self.lr,
                                   momentum=self.momentum)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   step_size=140)
        self.criterion = nn.NLLLoss().to(self.device)

    def getIw(self):
        # Iw should be normalized with respect to N.
        # Via reparameterization, alpha is optimized over only 1920 dimensions,
        # but Iw should scale with the full dimension of the weights
        # (7 * 7 * 64 * 384), hence the rescaling factor below.
        return 7 * 7 * 64 * 384 / 1920 * self.model.getIw() / self.batch_size

    def do_batch(self, train, epoch):
        loader = self.train_loader if train else self.test_loader
        total_ce, total_Iw, total_loss = 0, 0, 0
        total_correct = 0
        total = 0
        pbar = tqdm(loader)
        num_batches = len(loader)
        for batch_num, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)
            if train:
                self.optimizer.zero_grad()
            output = self.model(data)
            # NLLLoss is averaged across observations for each minibatch
            ce = self.criterion(torch.log(output + EPS), target)
            Iw = self.getIw()
            loss = ce + 0.5 * self.beta * Iw
            if train:
                loss.backward()
                self.optimizer.step()
            total_ce += ce.item()
            total_Iw += Iw.item()
            total_loss += loss.item()
            prediction = torch.max(
                output,
                1)  # second param "1" represents the dimension to be reduced
            total_correct += np.sum(
                prediction[1].cpu().numpy() == target.cpu().numpy())
            total += target.size(0)

            a = self.model.get_a()
            pbar.set_description('Train' if train else 'Test')
            pbar.set_postfix(N=self.N,
                             b=self.beta,
                             ep=epoch,
                             acc=100. * total_correct / total,
                             loss=total_loss / num_batches,
                             ce=total_ce / num_batches,
                             Iw=total_Iw / num_batches,
                             a=a)
        return total_correct / total, total_loss / num_batches, total_ce / num_batches, total_Iw / num_batches, a

    def train(self, epoch):
        self.model.train()
        return self.do_batch(train=True, epoch=epoch)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            return self.do_batch(train=False, epoch=epoch)

    def save(self, name=None):
        model_out_path = (name or self.name) + ".pth"
        # torch.save(self.model, model_out_path)
        # print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        self.load_data()
        self.load_model()
        results = []
        best_acc, best_ep = -1, -1
        for epoch in range(1, self.epochs + 1):
            # print("\n===> epoch: %d/200" % epoch)
            train_acc, train_loss, train_ce, train_Iw, train_a = self.train(
                epoch)
            self.scheduler.step()  # the epoch argument is deprecated in recent PyTorch
            test_acc, test_loss, test_ce, test_Iw, test_a = self.test(epoch)
            results.append([
                self.N, self.beta, train_acc, test_acc, train_loss, test_loss,
                train_ce, test_ce, train_Iw, test_Iw, train_a, test_a
            ])

            if test_acc > best_acc:
                best_acc, best_ep = test_acc, epoch
            if self.patience >= 0:  # early stopping
                if best_ep < epoch - self.patience:
                    break

        with open(self.name + '.csv', 'a') as f:
            w = csv.writer(f)
            w.writerows(results)
        self.save()

        return train_acc, test_acc
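A hedged sketch of driving Solver; the config fields mirror the ones read in __init__, but every value below is a placeholder:

import torch
from argparse import Namespace

config = Namespace(name='alexnet_cifar10', lr=0.01, momentum=0.9, beta=0.001,
                   max_alpha=1.0, epochs=200, patience=20, N=50000,
                   batch_size=100, random_labels=False, batchnorm=True,
                   cuda=torch.cuda.is_available())
solver = Solver(config)
train_acc, test_acc = solver.run()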
Example #11
        structure_loss = -torch.sum(torch.mul(fake_eig_vecs, real_eig_vecs), 0)
        normalized_real_eig_vals = normalize_min_max(real_eig_vals)
        weighted_structure_loss = torch.sum(
            torch.mul(normalized_real_eig_vals, structure_loss))
        return magnitude_loss + weighted_structure_loss
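    # normalize_min_max is not defined in this fragment; a standard min-max
    # scaling consistent with its use on the eigenvalues (an assumption, not
    # the original implementation) would be:
    #   def normalize_min_max(v, eps=1e-8):
    #       return (v - v.min()) / (v.max() - v.min() + eps)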

    netG = Generator(ngpu).to(device)
    netG.apply(weights_init)
    if opt.netG != '':
        netG.load_state_dict(torch.load(opt.netG))
    print(netG)

    netC = AlexNet(ngpu).to(device)
    netC.load_state_dict(torch.load('./best_model.pth'))
    print(netC)
    netC.eval()

    netD = Discriminator(ngpu).to(device)
    netD.apply(weights_init)
    if opt.netD != '':
        netD.load_state_dict(torch.load(opt.netD))
    print(netD)

    criterion = nn.BCELoss()
    criterion_sum = nn.BCELoss(reduction='sum')

    fixed_noise = torch.randn(opt.batchSize, 100, 1, 1, device=device)

    real_label = 1
    fake_label = 0