print('Model restored! Epoch:', i) start_epoch = i + 1 break if start_epoch == 0: assert False, "could not resume" if args.ngpu > 1: net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu))) if args.ngpu > 0: net.cuda() torch.cuda.manual_seed(1) cudnn.benchmark = True # fire on all cylinders net.eval() concat = lambda x: np.concatenate(x, axis=0) to_np = lambda x: x.data.to('cpu').numpy() def evaluate(loader): confidence = [] correct = [] num_correct = 0 with torch.no_grad(): for data, target in loader: data, target = data.cuda(), target.cuda() output = net(2 * data - 1)
def __init__(self, root='~/home-nfs/dan/cifar_data', train=True, gold=True,
             gold_fraction=0.1, corruption_prob=0, corruption_type='unif',
             transform=None, target_transform=None, download=False,
             shuffle_indices=None, distinguish_gold=True, seed=1):
    """Load CIFAR-10/100 and build a trusted (gold) or corrupted (silver) split.

    Training examples are ordered by a permutation ``indices``; the first
    ``int(gold_fraction * num_train)`` examples form the gold subset and the
    remaining ones form the silver subset, whose labels are resampled
    according to ``corruption_type``.

    Args:
        root: Dataset directory. A leading ``~`` is expanded.
        train: Load the training split if True, otherwise the test split.
        gold: With ``train=True``, build the gold subset if ``gold is True``,
            otherwise the silver subset.
        gold_fraction: Fraction of the training set reserved as gold.
        corruption_prob: Corruption strength used to build the corruption
            matrix (silver subset only).
        corruption_type: One of 'unif', 'flip', 'hierarchical', 'clabels'.
        transform: Stored on the instance (presumably applied in
            ``__getitem__``; not used here).
        target_transform: Stored on the instance (not used here).
        download: Call ``self.download()`` before the integrity check.
        shuffle_indices: Permutation of training indices to reuse. Passing
            the ``shuffle_indices`` attribute of a gold instance to a silver
            instance makes the two subsets disjoint. If None, the gold branch
            generates one and the silver branch uses the identity order.
        distinguish_gold: If True, gold labels are offset by ``num_classes``
            so gold examples can be told apart from silver ones.
        seed: Seeds NumPy's global RNG before silver labels are resampled,
            and is forwarded to ``flip_labels_C``.

    Raises:
        RuntimeError: If the dataset files are missing or corrupted.
        ValueError: For an unknown ``corruption_type``, or 'hierarchical'
            used with a non-CIFAR-100 dataset.
    """
    self.root = os.path.expanduser(root)
    self.transform = transform
    self.target_transform = target_transform
    self.train = train  # training set or test set
    self.gold = gold
    self.gold_fraction = gold_fraction
    self.corruption_prob = corruption_prob

    if download:
        self.download()

    if not self._check_integrity():
        raise RuntimeError('Dataset not found or corrupted.' +
                           ' You can use download=True to download it')

    def load_batch(filename):
        """Unpickle one batch file located under ``self.base_folder``."""
        path = os.path.join(self.root, self.base_folder, filename)
        # NOTE: pickle must only be used on trusted files.
        with open(path, 'rb') as fo:
            if sys.version_info[0] == 2:
                return pickle.load(fo)
            return pickle.load(fo, encoding='latin1')

    if self.train:
        self.train_data = []
        self.train_labels = []
        self.train_coarse_labels = []
        for fentry in self.train_list:
            entry = load_batch(fentry[0])
            self.train_data.append(entry['data'])
            if 'labels' in entry:  # CIFAR-10 style batch
                self.train_labels += entry['labels']
                num_classes = 10
            else:  # CIFAR-100 style batch: fine + coarse labels
                self.train_labels += entry['fine_labels']
                self.train_coarse_labels += entry['coarse_labels']
                num_classes = 100

        self.train_data = np.concatenate(self.train_data)
        self.train_data = self.train_data.reshape((-1, 3, 32, 32))
        self.train_data = self.train_data.transpose((0, 2, 3, 1))  # convert to HWC
        num_train = len(self.train_data)
        num_gold = int(gold_fraction * num_train)

        if gold is True:
            if shuffle_indices is not None:
                indices = shuffle_indices
            else:
                indices = np.arange(num_train)
                labels = np.asarray(self.train_labels)
                if gold_fraction >= 0.05:
                    # Reshuffle until the gold prefix contains every class.
                    # NOTE(review): uses the global RNG as left by the caller;
                    # `seed` is only applied in the silver branch.
                    shuffled = labels
                    while len(set(shuffled[:num_gold])) < num_classes:
                        np.random.shuffle(indices)
                        shuffled = labels[indices]
                else:
                    # Small gold fraction: take an equal number of examples
                    # per class, then append every other index in ascending
                    # order.
                    per_class = int(gold_fraction * num_train / num_classes)
                    gold_indices = np.concatenate(
                        [indices[labels == c][:per_class]
                         for c in range(num_classes)])
                    indices = np.concatenate(
                        [gold_indices, np.setdiff1d(indices, gold_indices)])

            self.train_data = self.train_data[indices][:num_gold]
            gold_labels = np.array(self.train_labels)[indices][:num_gold]
            if distinguish_gold:
                # this ad-hoc move is done so we can identify which examples
                # are gold/trusted and which are silver/untrusted
                gold_labels = gold_labels + num_classes
            self.train_labels = list(gold_labels)
            self.shuffle_indices = indices
        else:
            if shuffle_indices is None:
                indices = np.arange(num_train)
            else:
                indices = shuffle_indices
            self.train_data = self.train_data[indices][num_gold:]
            self.train_labels = list(
                np.array(self.train_labels)[indices][num_gold:])
            if corruption_type == 'hierarchical':
                self.train_coarse_labels = list(
                    np.array(self.train_coarse_labels)[indices][num_gold:])

            if corruption_type == 'unif':
                C = uniform_mix_C(self.corruption_prob, num_classes)
            elif corruption_type == 'flip':
                C = flip_labels_C(self.corruption_prob, num_classes, seed=seed)
            elif corruption_type == 'hierarchical':
                if num_classes != 100:
                    raise ValueError(
                        'You must use CIFAR-100 with the hierarchical corruption.')
                # Collect the fine labels present under each of the 20 coarse labels.
                coarse_fine = [set() for _ in range(20)]
                for coarse, fine in zip(self.train_coarse_labels,
                                        self.train_labels):
                    coarse_fine[coarse].add(fine)
                # Keep 1 - p on the diagonal and spread p uniformly over the
                # other fine labels that share the same coarse label.
                C = np.eye(num_classes) * (1 - corruption_prob)
                for group in coarse_fine:
                    group = list(group)
                    for j, fine in enumerate(group):
                        others = group[:j] + group[j + 1:]
                        C[fine, others] += corruption_prob * 1 / len(others)
            elif corruption_type == 'clabels':
                net = WideResNet(40, num_classes, 2, dropRate=0.3).cuda()
                model_name = './cifar{}_labeler'.format(num_classes)
                net.load_state_dict(torch.load(model_name))
                net.eval()
            else:
                # Literal braces must be doubled, otherwise str.format treats
                # them as a replacement field and raises KeyError.
                raise ValueError(
                    "Invalid corruption type '{}' given. Must be in "
                    "{{'unif', 'flip', 'hierarchical', 'clabels'}}".format(
                        corruption_type))

            np.random.seed(seed)
            if corruption_type == 'clabels':
                mean = [x / 255 for x in [125.3, 123.0, 113.9]]
                std = [x / 255 for x in [63.0, 62.1, 66.7]]
                test_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize(mean, std)
                ])

                # obtain sampling probabilities
                batch_size = 64
                sampling_probs = []
                print('Starting labeling')
                with torch.no_grad():  # inference only; no autograd graph
                    # Stepping by batch_size never yields an empty trailing
                    # batch (torch.cat of an empty list would raise).
                    for start in range(0, len(self.train_labels), batch_size):
                        images = self.train_data[start:start + batch_size]
                        batch = torch.cat([
                            test_transform(Image.fromarray(img)).unsqueeze(0)
                            for img in images
                        ], dim=0)
                        logits = net(V(batch).cuda())
                        smax = F.softmax(logits / 5, dim=1)  # temperature of 5
                        sampling_probs.append(smax.data.cpu().numpy())
                sampling_probs = np.concatenate(sampling_probs, 0)
                print('Finished labeling 1')

                # Resample each silver label from the labeler's distribution.
                new_labeling_correct = 0
                argmax_labeling_correct = 0
                for i, probs in enumerate(sampling_probs):
                    old_label = self.train_labels[i]
                    new_label = np.random.choice(num_classes, p=probs)
                    self.train_labels[i] = new_label
                    if old_label == new_label:
                        new_labeling_correct += 1
                    if old_label == np.argmax(probs):
                        argmax_labeling_correct += 1
                print('Finished labeling 2')
                print('New labeling accuracy:',
                      new_labeling_correct / len(self.train_labels))
                print('Argmax labeling accuracy:',
                      argmax_labeling_correct / len(self.train_labels))
            else:
                # Resample each silver label from its row of C.
                for i, label in enumerate(self.train_labels):
                    self.train_labels[i] = np.random.choice(
                        num_classes, p=C[label])
                self.corruption_matrix = C
    else:
        entry = load_batch(self.test_list[0][0])
        self.test_data = entry['data']
        if 'labels' in entry:
            self.test_labels = entry['labels']
        else:
            self.test_labels = entry['fine_labels']
        self.test_data = self.test_data.reshape((-1, 3, 32, 32))
        self.test_data = self.test_data.transpose((0, 2, 3, 1))  # convert to HWC