def run_alexnet_ann_recall_test_simulation_trial3():
    # Instantiate AlexNet trained on MNIST
    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt",
                                        map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc1)  # for now capture final output
    return run_alexnet_ann_recall_simulation(alex_cnn=alex_cnn,
                                             alex_capture=alex_capture,
                                             output_name="alexnet_recall_task_trial3.txt",
                                             num_nodes=1024)
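# Intermediate_Capture is used above but not defined in this listing. A minimal
# sketch of what it plausibly looks like, assuming it wraps a standard PyTorch
# forward hook that stashes the wrapped module's output on every forward pass
# (the class name comes from this file; the body is an assumption):
class Intermediate_Capture:
    def __init__(self, module):
        self.output = None
        # register_forward_hook is standard PyTorch; the hook fires on each forward pass
        self.handle = module.register_forward_hook(self._hook)

    def _hook(self, module, inputs, output):
        self.output = output.detach()  # keep the captured activation off the autograd graph

    def remove(self):
        self.handle.remove()  # detach the hook when no longer needed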
        out = model(batch_x)
        loss = loss_func(out, batch_y)
        train_loss += loss.item()
        pred = torch.max(out, 1)[1]
        train_correct = (pred == batch_y).sum()
        train_accu += train_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    mean_loss = train_loss / len(train_data)
    mean_accu = train_accu / len(train_data)
    print('Training Loss: %.6f, Accu: %.6f' % (mean_loss, mean_accu))

    # evaluation ------------------------
    model.eval()
    eval_loss = 0
    eval_accu = 0
    with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True) idiom
        for batch_x, batch_y in test_loader:
            batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
            out = model(batch_x)
            loss = loss_func(out, batch_y)
            eval_loss += loss.item()  # loss.data[0] is deprecated since PyTorch 0.4
            pred = torch.max(out, 1)[1]
            num_correct = (pred == batch_y).sum()
            eval_accu += num_correct.item()
    mean_loss = eval_loss / len(test_data)
    mean_accu = eval_accu / len(test_data)
    print('Testing Loss: %.6f, Accu: %.6f' % (mean_loss, mean_accu))
def test_sample_images(path_to_model, path_to_images, save_path):
    num_classes = 27

    # Load pre-trained AlexNet
    state_dict = torch.load(path_to_model,
                            map_location=lambda storage, loc: storage)['model']
    model = AlexNet(num_classes)
    model.load_state_dict(state_dict)
    model.eval()

    # Process every image
    dictionary = set(nltk.corpus.words.words())
    distances = defaultdict(lambda: defaultdict(lambda: 0))
    size_distances = defaultdict(lambda: defaultdict(lambda: 0))
    corrected_words = defaultdict(lambda: defaultdict(lambda: 0))
    with open('{}labels.txt'.format(path_to_images)) as f:
        for line in f:
            sections = line.split('; ')
            if len(sections) < 2:
                continue
            fname = sections[0]
            correct_word = sections[1].strip()  # drop the trailing newline

            # Open image
            image = cv2.imread('{}{}'.format(path_to_images, fname))
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            _, image = cv2.threshold(image, 90, 255, cv2.THRESH_BINARY_INV)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)

            # Find the 5 most probable results
            subimages = extract_characters(image, bounding_boxes)
            results = classify_characters(model, subimages)
            results = results[:5]

            # Check if the word can be corrected: take the first candidate found
            # in the dictionary ('is' was the wrong comparison here; use ==)
            corrected_word = ''
            for word in results:
                if word[0].lower() in dictionary and corrected_word == '':
                    corrected_word = word[0]

            # Append to dicts for evaluation
            most_probable_word = results[0][0]
            distance = Levenshtein.distance(most_probable_word, correct_word)
            distances[len(correct_word)][distance] += 1
            size_distances[len(correct_word)][len(most_probable_word)] += 1
            corrected_words[len(correct_word)][0] += 1
            if corrected_word == correct_word:
                corrected_words[len(correct_word)][1] += 1

            # Print information about current progress
            print('Correct: {:12s} Most probable: {:12s} Corrected: {:12s} '
                  'Distance: {:1d} Success: {}'.format(
                      correct_word, most_probable_word, corrected_word,
                      distance, corrected_word == correct_word))

    # Save results
    with open('{}/test_results_distance.txt'.format(save_path), 'w') as f:
        for size in sorted(distances):
            for distance in sorted(distances[size]):
                f.write('{};{};{}\n'.format(size, distance,
                                            distances[size][distance]))
    with open('{}/test_results_size.txt'.format(save_path), 'w') as f:
        for size in sorted(size_distances):
            for size_distance in sorted(size_distances[size]):
                f.write('{};{};{}\n'.format(size, size_distance,
                                            size_distances[size][size_distance]))
    with open('{}/test_results_corrected.txt'.format(save_path), 'w') as f:
        for key in sorted(corrected_words):
            for count in sorted(corrected_words[key]):
                f.write('{};{};{}\n'.format(key, count,
                                            corrected_words[key][count]))
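# Usage sketch for test_sample_images (all three paths below are hypothetical
# placeholders; note that path_to_images must end with '/' because the function
# concatenates it directly with 'labels.txt' and the image file names):
# test_sample_images(path_to_model='./trained_models/letters_10.pth',
#                    path_to_images='./data/sample_images/',
#                    save_path='./results')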
def train(train_loader, eval_loader, opt):
    print('==> Start training...')
    summary_writer = SummaryWriter('./runs/' + str(int(time.time())))

    is_cuda = torch.cuda.is_available()
    model = AlexNet()
    if is_cuda:
        model = model.cuda()
    optimizer = optim.SGD(params=model.parameters(), lr=opt.base_lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()

    best_eval_acc = -0.1
    losses = AverageMeter()
    accuracies = AverageMeter()
    global_step = 0
    for epoch in range(1, opt.epochs + 1):
        # train
        model.train()
        losses.reset()       # reset meters so each epoch's averages are independent
        accuracies.reset()
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            global_step += 1
            if is_cuda:
                inputs = inputs.cuda()
                targets = targets.cuda()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            losses.update(loss.item(), outputs.shape[0])
            summary_writer.add_scalar('train/loss', loss.item(), global_step)
            _, preds = torch.max(outputs, dim=1)
            acc = preds.eq(targets).sum().item() / len(targets)
            accuracies.update(acc)
            summary_writer.add_scalar('train/acc', acc, global_step)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            summary_writer.add_scalar('lr', optimizer.param_groups[0]['lr'],
                                      global_step)
        print('==> Epoch: %d; Average Train Loss: %.4f; Average Train Acc: %.4f'
              % (epoch, losses.avg, accuracies.avg))

        # eval
        model.eval()
        losses.reset()
        accuracies.reset()
        with torch.no_grad():  # no gradients needed during evaluation
            for batch_idx, (inputs, targets) in enumerate(eval_loader):
                if is_cuda:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                losses.update(loss.item(), outputs.shape[0])
                _, preds = torch.max(outputs, dim=1)
                acc = preds.eq(targets).sum().item() / len(targets)
                accuracies.update(acc)
        summary_writer.add_scalar('eval/loss', losses.avg, global_step)
        summary_writer.add_scalar('eval/acc', accuracies.avg, global_step)
        if accuracies.avg > best_eval_acc:
            best_eval_acc = accuracies.avg
            torch.save(model, './weights/best.pt')
        print('==> Epoch: %d; Average Eval Loss: %.4f; Average/Best Eval Acc: %.4f / %.4f'
              % (epoch, losses.avg, accuracies.avg, best_eval_acc))
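# AverageMeter is used by train() but not defined in this listing. A minimal
# sketch of the conventional implementation, assuming the usual
# update(val, n) / avg / reset interface that the loop above relies on:
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0.0    # weighted sum of all observed values
        self.count = 0    # total weight seen so far
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count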
class TestNetwork():
    def __init__(self, dataset, batch_size, epochs):
        self.dataset = dataset
        self.batch_size = batch_size
        self.epochs = epochs

        # letters contains 27 classes, digits contains 10 classes
        num_classes = 27 if dataset == 'letters' else 10

        # Load model and use CUDA if available
        self.model = AlexNet(num_classes)
        if torch.cuda.is_available():
            self.model.cuda()

        # Load testing dataset
        kwargs = {'num_workers': 1,
                  'pin_memory': True} if torch.cuda.is_available() else {}
        self.test_loader = torch.utils.data.DataLoader(
            EMNIST('./data', dataset, download=True,
                   transform=transforms.Compose([
                       transforms.Lambda(correct_rotation),
                       transforms.Resize((224, 224)),
                       transforms.Grayscale(3),
                       transforms.ToTensor(),
                   ]),
                   train=False),
            batch_size=batch_size, shuffle=True, **kwargs)

        # Loss function
        self.loss_fn = nn.CrossEntropyLoss()

    def test(self, epoch):
        """
        Test the model for one epoch with a pre-trained network

        :param epoch: Current epoch
        :return: None
        """
        # Load weights from the trained model
        state_dict = torch.load(
            './trained_models/{}_{}.pth'.format(self.dataset, epoch),
            map_location=lambda storage, loc: storage)['model']
        self.model.load_state_dict(state_dict)
        self.model.eval()

        test_loss = 0
        test_correct = 0
        progress = None
        with torch.no_grad():  # replaces the deprecated Variable(..., volatile=True) idiom
            for batch_idx, (data, target) in enumerate(self.test_loader):
                # Get data and label
                if torch.cuda.is_available():
                    data, target = data.cuda(), target.cuda()

                output = self.model(data)
                loss = self.loss_fn(output, target)
                test_loss += loss.item()  # loss.data[0] is deprecated
                pred = output.max(1, keepdim=True)[1]
                test_correct += pred.eq(target.view_as(pred)).sum().item()

                # Print information about the current step
                # ('!=' rather than 'is not': compare values, not identities)
                current_progress = int(100 * (batch_idx + 1) * self.batch_size
                                       / len(self.test_loader.dataset))
                if current_progress != progress and current_progress % 5 == 0:
                    progress = current_progress
                    print('Test Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, (batch_idx + 1) * len(data),
                        len(self.test_loader.dataset), current_progress,
                        loss.item()))

        test_loss /= (len(self.test_loader.dataset) / self.batch_size)
        test_correct /= len(self.test_loader.dataset)
        test_correct *= 100

        # Print information about the current epoch
        print('Test Epoch: {} \tCorrect: {:3.2f}%\tAverage loss: {:.6f}'.format(
            epoch, test_correct, test_loss))

    def start(self):
        """
        Start testing the network

        :return: None
        """
        for epoch in range(1, self.epochs + 1):
            self.test(epoch)
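# correct_rotation is referenced in the transform pipeline above but not shown.
# Raw EMNIST images are stored rotated and mirrored relative to MNIST; the usual
# fix is a transpose. A sketch of what the transform might look like (the
# function name comes from this file; the body is an assumption):
import numpy as np
from PIL import Image

def correct_rotation(image):
    # Transposing the pixel array undoes EMNIST's rotated/mirrored storage
    return Image.fromarray(np.array(image).T)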
import torch
from torch.nn.functional import softmax

from alexnet import AlexNet
from utils import cifar10_loader, device, cifar10_classes

torch.random.manual_seed(128)

batch_size = 1
testloader = cifar10_loader(train=False, batch_size=batch_size)

net = AlexNet()
net.load_state_dict(torch.load("model/model.h5"))
net.to(device)  # move the model to the same device as the inputs
net.eval()

correct = 0
total = 0

def run():
    global correct, total
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            inputs, labels = images.to(device), labels.to(device)
            outputs = net(inputs)
            _, predicted = torch.topk(outputs.data, 5)
            indexes = predicted.cpu().numpy()[0].tolist()
            # For inspection: softmax(outputs).cpu().numpy()[0][indexes] gives the
            # top-5 probabilities, [cifar10_classes[i] for i in indexes] their names.
            # Count a hit when the true label appears among the top-5 predictions
            if labels.item() in indexes:
                correct += 1
            total += 1
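# Usage sketch (assumption: this file is meant to be run as a script and to
# report top-5 accuracy from the counters that run() fills in):
if __name__ == "__main__":
    run()
    print("Top-5 accuracy: {:.2f}%".format(100.0 * correct / total))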
def run_alexnet_ann_recall_test_simulation_trial7():
    output_name = "alexnet_recall_task_trial7.txt"
    num_nodes = 10
    full_connection_mat = np.ones(shape=(num_nodes, num_nodes)) - np.eye(num_nodes)

    alex_cnn = AlexNet()
    alex_cnn.load_state_dict(torch.load("trained_models/alexnet.pt",
                                        map_location=torch.device("cpu")))
    alex_cnn.eval()
    alex_capture = Intermediate_Capture(alex_cnn.fc3)  # for now capture final output

    transform = transforms.ToTensor()
    data_raw = MNIST(root='./data/mnist', train=True, download=True,
                     transform=transform)

    # Create a toy dataset for simple probing: 10 instances of each digit class
    mnist_subset = {k: [] for k in range(10)}
    per_class_sizes = {k: 10 for k in range(10)}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1, 1, 28, 28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done = False
        if done:
            break

    # Convert mnist_subset into a table usable for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # Given a list of desired labels, randomly choose an example of each label
    # from the MNIST dataset to store
    stored_size_vs_performance = []  # stores tuples of (hopfield perf, popularity perf, ortho perf)
    for desired_label_size in range(10):
        # Regenerate the probe set each time: when the desired label size is k,
        # the probe set is 10 instances each of labels 0 to k-1
        desired_labels = list(range(desired_label_size + 1))
        sub_probe_set = []
        sub_probe_labels = []
        for des in desired_labels:
            # add 10 instances of des
            for inst in mnist_subset[des]:
                sub_probe_set.append(inst)
                sub_probe_labels.append(des)
        full_stored_set, full_stored_labels = create_storage_set(
            desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # Evaluate hopnet performance
        ann_model = hopnet(num_nodes)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            sub_probe_set, sub_probe_labels, verbose=False)
        print("Hopfield:", num_succ, ":", num_fail)
        hopfield_perf = int(num_succ)

        # Evaluate popularity ANN performance
        # Hyperparams: c = N-1, with a fully connected connectivity matrix
        # (all-ones minus the diagonal)
        ann_model = PopularityANN(N=num_nodes, c=num_nodes - 1,
                                  connectivity_matrix=full_connection_mat)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            sub_probe_set, sub_probe_labels, verbose=False)
        print("PopularityANN:", num_succ, ":", num_fail)
        popularity_perf = int(num_succ)

        # Evaluate orthogonal Hebbs ANN performance
        ann_model = OrthogonalHebbsANN(N=num_nodes)
        model = CNN_ANN(alex_cnn, ann_model, alex_capture,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            sub_probe_set, sub_probe_labels, verbose=False)
        print("OrthogonalHebbsANN:", num_succ, ":", num_fail)
        ortho_perf = int(num_succ)

        stored_size_vs_performance.append((hopfield_perf, popularity_perf, ortho_perf))

    # Write performance to file
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write(str(perf[0]) + "," + str(perf[1]) + "," + str(perf[2]) + "\n")
    return stored_size_vs_performance
def run_alexnet_ann_recall_test_simulation_trial4():
    num_nodes = 10
    output_name = "alexnet_recall_task_trial4.txt"

    alex_cnn1 = AlexNet()
    alex_cnn1.load_state_dict(torch.load("trained_models/alexnet.pt",
                                         map_location=torch.device("cpu")))
    alex_cnn1.eval()
    alex_capture1 = Intermediate_Capture(alex_cnn1.layer3)  # capture layer3 output

    alex_cnn2 = AlexNet()
    alex_cnn2.load_state_dict(torch.load("trained_models/alexnet.pt",
                                         map_location=torch.device("cpu")))
    alex_cnn2.eval()
    alex_capture2 = Intermediate_Capture(alex_cnn2.layer4)  # capture layer4 output

    alex_cnn3 = AlexNet()
    alex_cnn3.load_state_dict(torch.load("trained_models/alexnet.pt",
                                         map_location=torch.device("cpu")))
    alex_cnn3.eval()
    alex_capture3 = Intermediate_Capture(alex_cnn3.layer5)  # capture layer5 output

    transform = transforms.ToTensor()
    data_raw = MNIST(root='./data/mnist', train=True, download=True,
                     transform=transform)

    # Create a toy dataset for simple probing: 10 instances of each digit class
    mnist_subset = {k: [] for k in range(10)}
    per_class_sizes = {k: 10 for k in range(10)}
    for i in range(len(data_raw)):
        image, label = data_raw[i]
        if len(mnist_subset[label]) < per_class_sizes[label]:
            mnist_subset[label].append(torch.reshape(image, (1, 1, 28, 28)))
        done = True
        for k in mnist_subset:
            if len(mnist_subset[k]) < per_class_sizes[k]:
                done = False
        if done:
            break

    # Convert mnist_subset into a table usable for model input
    full_pattern_set = []
    full_label_set = []
    for k in mnist_subset:
        for v in mnist_subset[k]:
            full_pattern_set.append(v)
            full_label_set.append(k)

    # Given a list of desired labels, randomly choose an example of each label
    # from the MNIST dataset to store
    stored_size_vs_performance = []  # stores tuples of (layer3 perf, layer4 perf, layer5 perf)
    for desired_label_size in range(10):
        desired_labels = list(range(desired_label_size + 1))
        full_stored_set, full_stored_labels = create_storage_set(
            desired_labels, mnist_subset, reshape=False, make_numpy=False)
        print("Num Stored: ", len(desired_labels))

        # Evaluate hopnet performance on layer3 activations (6272 units)
        ann_model = hopnet(6272)
        model = CNN_ANN(alex_cnn1, ann_model, alex_capture1,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer3:", num_succ, ":", num_fail)
        layer3_perf = int(num_succ)

        # Layer4 activations (12544 units)
        ann_model = hopnet(12544)
        model = CNN_ANN(alex_cnn2, ann_model, alex_capture2,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer4:", num_succ, ":", num_fail)
        layer4_perf = int(num_succ)

        # Layer5 activations (2304 units)
        ann_model = hopnet(2304)
        model = CNN_ANN(alex_cnn3, ann_model, alex_capture3,
                        capture_process_fn=lambda x: np.sign(np.exp(x) - np.exp(x).mean()))
        num_succ, num_fail = evaluate_model_recall(
            model, full_stored_set, full_stored_labels,
            full_stored_set, full_stored_labels, verbose=False)
        print("Alexnet Layer5:", num_succ, ":", num_fail)
        layer5_perf = int(num_succ)

        stored_size_vs_performance.append((layer3_perf, layer4_perf, layer5_perf))

    # Write performance to file
    with open("data/graph_sources/" + output_name, "w") as fh:
        for perf in stored_size_vs_performance:
            fh.write(str(perf[0]) + "," + str(perf[1]) + "," + str(perf[2]) + "\n")
    return stored_size_vs_performance
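# The capture_process_fn lambda used in both trials above turns a captured
# activation vector into the +/-1 pattern the attractor networks expect:
# exponentiate (making all values positive), then sign against the mean.
# A standalone restatement of that step:
def binarize_activations(x):
    e = np.exp(x)
    return np.sign(e - e.mean())  # +1 for above-average units, -1 for below

# The hopnet sizes above (6272, 12544, 2304) are hard-coded to match the
# flattened size of each captured layer's activations. A hedged sketch of how
# such a size can be found with a dummy forward pass, assuming the capture
# object stores the hook output in an `output` attribute as sketched earlier:
def captured_size(cnn, capture, input_shape=(1, 1, 28, 28)):
    with torch.no_grad():
        cnn(torch.zeros(input_shape))
    return capture.output.numel()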
class LiveShowcase:
    def __init__(self, path_to_model):
        num_classes = 27

        # Member variables
        self.status = 'Ready'
        self.last_words = None
        self.dictionary_set = set(nltk.corpus.words.words())

        # Load pre-trained AlexNet
        state_dict = torch.load(path_to_model,
                                map_location=lambda storage, loc: storage)['model']
        self.model = AlexNet(num_classes)
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def process_image(self, image, bounding_boxes):
        """
        Process image to find and classify characters and build the 5 most probable words

        :param image: rgb image
        :param bounding_boxes: list of bounding boxes containing characters
                               (min_x, min_y, width, height)
        :return: None
        """
        self.status = 'Processing'

        # Find the 5 most probable words
        subimages = extract_characters(image, bounding_boxes)
        words = classify_characters(self.model, subimages)
        self.last_words = words[:5]
        self.status = 'Ready'

    def start(self, max_bounding_boxes=10):
        """
        Start the live showcase using a camera

        :return: None
        """
        # Try to open a connection to the camera
        cap = cv2.VideoCapture(0)
        if not cap.isOpened():
            print('Error: No camera found')
            return
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 960)

        print('Press q to stop the live showcase')
        while True:
            # Capture frame-by-frame
            ret, image = cap.read()
            output = image

            # Find bounding boxes for each character
            image = preprocess_image(image)
            bounding_boxes = find_bounding_boxes(image)
            bounding_boxes = filter_bounding_boxes(image, bounding_boxes)
            for box in bounding_boxes:
                cv2.rectangle(output, (box[0], box[1]),
                              (box[0] + box[2], box[1] + box[3]), (0, 0, 255), 2)

            # Process image if no other image is being processed
            if 'Ready' in self.status:
                if len(bounding_boxes) > max_bounding_boxes:
                    self.status = 'Ready [Warning: too many bounding boxes]'
                    self.last_words = None
                else:
                    thread = threading.Thread(target=self.process_image,
                                              args=(image, bounding_boxes),
                                              daemon=True)
                    thread.start()

            # Draw status bar with the last recognized words
            cv2.putText(output, 'Status: {}'.format(self.status), (10, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
            if self.last_words:
                for offset, word in enumerate(self.last_words):
                    color = (0, 0, 0)
                    # Use green color if the word is in the dictionary
                    if word[0].lower() in self.dictionary_set:
                        color = (0, 255, 0)
                    cv2.putText(output,
                                '{} ({:5.2f}%)'.format(word[0], 100 * word[1]),
                                (10, 20 + (offset + 1) * 20),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

            # Draw bounding box around the detected word
            if self.last_words and len(bounding_boxes) > 1:
                word = self.last_words[0]
                color = (0, 0, 255)
                # Use green color if the word is in the dictionary
                if word[0].lower() in self.dictionary_set:
                    color = (0, 255, 0)
                text = '{} ({:5.2f}%)'.format(word[0], 100 * word[1])
                padding = 10
                top_left = (np.min([b[0] for b in bounding_boxes]) - padding,
                            np.min([b[1] for b in bounding_boxes]) - padding)
                bottom_right = (np.max([b[0] + b[2] for b in bounding_boxes]) + padding,
                                np.max([b[1] + b[3] for b in bounding_boxes]) + padding)
                text_size = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
                cv2.rectangle(output,
                              (top_left[0] - 1, top_left[1] - text_size[1] - 2 * padding),
                              (top_left[0] + text_size[0] + 2 * padding, top_left[1]),
                              color, thickness=cv2.FILLED)
                cv2.putText(output, text,
                            (top_left[0] + padding, top_left[1] - padding),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
                cv2.rectangle(output, top_left, bottom_right, color, 2)

            # Display the resulting frame
            cv2.imshow('Image internal', image)
            cv2.imshow('Showcase', output)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

        # When everything is done, release the capture
        cap.release()
        cv2.destroyAllWindows()
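# Usage sketch for LiveShowcase (the checkpoint path is a hypothetical placeholder):
# if __name__ == '__main__':
#     LiveShowcase('./trained_models/letters_10.pth').start(max_bounding_boxes=10)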
class Solver(object):
    def __init__(self, config):
        self.model = None
        self.name = config.name
        self.lr = config.lr
        self.momentum = config.momentum
        self.beta = config.beta
        self.max_alpha = config.max_alpha
        self.epochs = config.epochs
        self.patience = config.patience
        self.N = config.N
        self.batch_size = config.batch_size
        self.random_labels = config.random_labels
        self.use_bn = config.batchnorm
        self.criterion = None
        self.optimizer = None
        self.scheduler = None
        self.device = None
        self.cuda = config.cuda
        self.train_loader = None
        self.test_loader = None

    def load_data(self):
        # ToTensor scales pixel values from [0,255] to [0,1]
        mean_var = ((125.3 / 255, 123.0 / 255, 113.9 / 255),
                    (63.0 / 255, 62.1 / 255, 66.7 / 255))
        transform = transforms.Compose([
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            transforms.Normalize(*mean_var, inplace=True)
        ])
        train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                 download=DOWNLOAD,
                                                 transform=transform)
        test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                download=DOWNLOAD,
                                                transform=transform)
        if self.random_labels:
            np.random.shuffle(train_set.targets)
            np.random.shuffle(test_set.targets)
        assert self.N <= 50000
        if self.N < 50000:
            train_set.data = train_set.data[:self.N]
            # downsize the test set to improve speed for small N
            test_set.data = test_set.data[:self.N]
        self.train_loader = torch.utils.data.DataLoader(
            dataset=train_set, batch_size=self.batch_size,
            shuffle=True, drop_last=True)
        self.test_loader = torch.utils.data.DataLoader(
            dataset=test_set, batch_size=self.batch_size,
            shuffle=False, drop_last=True)

    def load_model(self):
        if self.cuda:
            self.device = torch.device('cuda')
            cudnn.benchmark = True
        else:
            self.device = torch.device('cpu')

        self.model = AlexNet(device=self.device, B=self.batch_size,
                             max_alpha=self.max_alpha,
                             use_bn=self.use_bn).to(self.device)
        self.optimizer = optim.SGD(self.model.parameters(), lr=self.lr,
                                   momentum=self.momentum)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=140)
        self.criterion = nn.NLLLoss().to(self.device)

    def getIw(self):
        # Iw should be normalized with respect to N.
        # Via reparameterization we optimize alpha with only 1920 dimensions,
        # but Iw should scale with the dimension of the weights.
        return 7 * 7 * 64 * 384 / 1920 * self.model.getIw() / self.batch_size

    def do_batch(self, train, epoch):
        loader = self.train_loader if train else self.test_loader
        total_ce, total_Iw, total_loss = 0, 0, 0
        total_correct = 0
        total = 0
        pbar = tqdm(loader)
        num_batches = len(loader)
        for batch_num, (data, target) in enumerate(pbar):
            data, target = data.to(self.device), target.to(self.device)
            if train:
                self.optimizer.zero_grad()
            output = self.model(data)
            # NLLLoss is averaged across observations for each minibatch
            ce = self.criterion(torch.log(output + EPS), target)
            Iw = self.getIw()
            loss = ce + 0.5 * self.beta * Iw
            if train:
                loss.backward()
                self.optimizer.step()
            total_ce += ce.item()
            total_Iw += Iw.item()
            total_loss += loss.item()
            # second param "1" is the dimension to reduce; [1] holds the argmax indices
            prediction = torch.max(output, 1)
            total_correct += np.sum(
                prediction[1].cpu().numpy() == target.cpu().numpy())
            total += target.size(0)
            a = self.model.get_a()
            pbar.set_description('Train' if train else 'Test')
            pbar.set_postfix(N=self.N, b=self.beta, ep=epoch,
                             acc=100. * total_correct / total,
                             loss=total_loss / num_batches,
                             ce=total_ce / num_batches,
                             Iw=total_Iw / num_batches, a=a)
        return (total_correct / total, total_loss / num_batches,
                total_ce / num_batches, total_Iw / num_batches, a)

    def train(self, epoch):
        self.model.train()
        return self.do_batch(train=True, epoch=epoch)

    def test(self, epoch):
        self.model.eval()
        with torch.no_grad():
            return self.do_batch(train=False, epoch=epoch)

    def save(self, name=None):
        model_out_path = (name or self.name) + ".pth"
        # torch.save(self.model, model_out_path)
        # print("Checkpoint saved to {}".format(model_out_path))

    def run(self):
        self.load_data()
        self.load_model()
        results = []
        best_acc, best_ep = -1, -1
        for epoch in range(1, self.epochs + 1):
            train_acc, train_loss, train_ce, train_Iw, train_a = self.train(epoch)
            self.scheduler.step(epoch)
            test_acc, test_loss, test_ce, test_Iw, test_a = self.test(epoch)
            results.append([
                self.N, self.beta, train_acc, test_acc, train_loss, test_loss,
                train_ce, test_ce, train_Iw, test_Iw, train_a, test_a
            ])
            if test_acc > best_acc:
                best_acc, best_ep = test_acc, epoch
            if self.patience >= 0:  # early stopping
                if best_ep < epoch - self.patience:
                    break
        with open(self.name + '.csv', 'a') as f:
            w = csv.writer(f)
            w.writerows(results)
        self.save()
        return train_acc, test_acc
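# Worked check of the constant in getIw (pure arithmetic, independent of the
# project code): the full convolutional weight block has 7*7*64*384 = 1,204,224
# dimensions, while the reparameterized alpha has only 1920, so each batch's Iw
# is scaled up accordingly before division by the batch size.
full_dim = 7 * 7 * 64 * 384   # 1,204,224 weight dimensions in the conv block
alpha_dim = 1920              # dimensions actually optimized via the reparameterization
assert full_dim / alpha_dim == 627.2  # the factor applied to model.getIw() per batch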
    structure_loss = -torch.sum(torch.mul(fake_eig_vecs, real_eig_vecs), 0)
    normalized_real_eig_vals = normalize_min_max(real_eig_vals)
    weighted_structure_loss = torch.sum(
        torch.mul(normalized_real_eig_vals, structure_loss))
    return magnitude_loss + weighted_structure_loss


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

netC = AlexNet(ngpu).to(device)
netC.load_state_dict(torch.load('./best_model.pth'))
print(netC)
netC.eval()

netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

criterion = nn.BCELoss()
criterion_sum = nn.BCELoss(reduction='sum')

fixed_noise = torch.randn(opt.batchSize, 100, 1, 1, device=device)
real_label = 1
fake_label = 0