def __init__(self, trunk_name, output_stride=8, pretrained=True):
    super(get_resnet, self).__init__()

    torchvision_trunks = {
        "resnet18": resnet18,
        "resnet34": resnet34,
        "resnet50": resnet50,
        "resnet101": resnet101,
        "resnet152": resnet152,
        "resnext101_32x8d": resnext101_32x8d,
    }
    if trunk_name in torchvision_trunks:
        resnet = torchvision_trunks[trunk_name](pretrained=pretrained)
        # Group the stem (conv1/bn1/relu/maxpool) into a single "layer0" stage.
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
    elif trunk_name == "se_resnet50":
        # Note: this branch actually builds the SE-ResNeXt-50 (32x4d) trunk.
        resnet = se_resnext50_32x4d(pretrained=pretrained)
    elif trunk_name == "se_resnet101":
        # Note: this branch actually builds the SE-ResNeXt-101 (32x4d) trunk.
        resnet = se_resnext101_32x4d(pretrained=pretrained)
    else:
        raise KeyError("[*] Not a valid network arch")

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    # Replace striding with dilation in the last stage(s) so the trunk's
    # overall output stride matches the requested value.
    if output_stride == 8:
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif output_stride == 16:
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        raise KeyError("[*] Unsupported output_stride {}".format(output_stride))
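# --- Illustrative example (not part of the original code) ---------------------
# A minimal, self-contained sketch of the same dilation trick applied to a plain
# torchvision ResNet-50 (Bottleneck blocks carry their stride in conv2), targeting
# an output stride of 16 instead of the default 32. Assumes torchvision >= 0.13
# for the `weights=None` constructor.
import torch
from torchvision.models import resnet50

net = resnet50(weights=None)
for name, module in net.layer4.named_modules():
    if 'conv2' in name:
        module.dilation, module.padding, module.stride = (2, 2), (2, 2), (1, 1)
    elif 'downsample.0' in name:
        module.stride = (1, 1)

x = torch.randn(1, 3, 224, 224)
x = net.maxpool(net.relu(net.bn1(net.conv1(x))))        # stem: stride 4
x = net.layer4(net.layer3(net.layer2(net.layer1(x))))   # stride 16 after the patch
print(x.shape)  # torch.Size([1, 2048, 14, 14]), i.e. 224 / 16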
def __init__(self, num_classes=256, input_bits=8, embed_dim=8, seq_len=100, backbone='resnet50'):
    super(CNN, self).__init__()
    self.input_bits = input_bits
    self.seq_len = seq_len
    # self.embed = nn.Embedding(256, 256)
    # Map each of the 2 ** input_bits possible input symbols to an embedding vector.
    self.embed = nn.Embedding(2 ** input_bits, embed_dim)
    backbones = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    # Instantiate only the selected backbone; fall back to resnet50 for unknown names.
    # (The original `.get(backbone, 'resnet50')` returned the *string* 'resnet50'
    # instead of a model, and eagerly constructed all five networks.)
    self.backbone = backbones.get(backbone, resnet50)(n_classes=num_classes,
                                                      input_channels=embed_dim)
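# --- Illustrative example (not part of the original code) ---------------------
# A minimal sketch of the embedding-as-channels idea used above: each raw byte is
# embedded, then the embedding dimension is moved to the channel axis so a
# convolutional backbone can consume it. The Conv1d here is a stand-in for the
# repo's ResNet backbone (whose exact input layout is an assumption).
import torch
import torch.nn as nn

embed_dim, seq_len = 8, 100
embed = nn.Embedding(256, embed_dim)                 # 2 ** 8 possible byte values
backbone = nn.Conv1d(embed_dim, 16, kernel_size=3, padding=1)

x = torch.randint(0, 256, (4, seq_len))              # batch of raw byte sequences
h = embed(x)                                         # (4, seq_len, embed_dim)
h = h.permute(0, 2, 1)                               # (4, embed_dim, seq_len) for Conv1d
print(backbone(h).shape)                             # torch.Size([4, 16, 100])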
def __init__(self, args, train_transform=None, test_transform=None, val_transform=None):
    self.args = args
    self.train_transform = train_transform
    self.test_transform = test_transform
    self.val_transform = val_transform

    self.loss_lams = torch.zeros(args.num_classes, args.num_classes, dtype=torch.float32).cuda()
    self.loss_lams[:, :] = 1. / args.num_classes
    self.loss_lams.requires_grad = False

    if self.args.arch == 'resnet50':
        self.net = resnet.resnet50(num_classes=args.num_classes)
    elif self.args.arch == 'res2net50':
        self.net = res2net.res2net50_26w_4s(num_classes=args.num_classes)
    elif self.args.arch == 'resnet101':
        self.net = resnet.resnet101()
    elif self.args.arch == 'resnet18':
        self.net = resnet.resnet18()
    elif self.args.arch == 'resnet34':
        self.net = resnet.resnet34()
    elif self.args.arch == 'vgg16':
        self.net = vgg_cub.vgg16()
    elif self.args.arch == 'vgg16_bn':
        self.net = vgg_cub.vgg16_bn()
    elif self.args.arch == 'vgg19':
        self.net = vgg_cub.vgg19()
    elif self.args.arch == 'vgg19_bn':
        self.net = vgg_cub.vgg19_bn()
    elif self.args.arch == 'vgg16_std':
        self.net = vgg_std.vgg16()
    elif self.args.arch == 'vgg16_bn_std':
        self.net = vgg_std.vgg16_bn()
    elif self.args.arch == 'mobilenetv2':
        self.net = mobilenetv2.mobilenet_v2(num_classes=args.num_classes)
    else:
        raise ValueError('unknown arch: %s' % self.args.arch)

    if self.args.load_model is not None:
        self.net.load_state_dict(torch.load(self.args.load_model), strict=True)
        print('load model from %s' % self.args.load_model)
    elif self.args.pretrained_model is not None:
        self.net.load_state_dict(torch.load(self.args.pretrained_model), strict=False)
        print('load pretrained model from %s' % self.args.pretrained_model)
    else:
        print('no model loaded, will train from scratch!')

    if args.expname is None:
        args.expname = 'runs/{}_{}_{}'.format(args.arch, args.dataset, args.method)
    os.makedirs(args.expname, exist_ok=True)
def testRed(num, visualize):
    # Set test parameters.
    params = TestParams()
    # Set 'params.gpus = []' to use the CPU; if len(params.gpus) > 1,
    # params.gpus[0] is used for testing by default.
    params.gpus = []
    # This checkpoint corresponds to the model trained for 60 epochs in the two
    # training curves under ./architecture:
    # red_*_train_1e5_test_2e4_10_kinds_3min_per_epoch_resnet18.png
    params.ckpt = './models/formation_prediction.pth'

    # Models
    # model = resnet34(pretrained=False, num_classes=1000)  # batch_size=120, 1 GPU memory < 7000M
    # model.fc = nn.Linear(512, 6)
    model = resnet18(pretrained=False, num_classes=1000)  # batch_size=60, 1 GPU memory > 9000M
    model.fc = nn.Linear(512, formation_num)

    # Test
    tester = RedTester(model, params)
    tester.test(num, visualize)
def main():
    MAX_EPOCH = 2

    model = ResNet.resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    model.summary()
    optimizer = optimizers.Adam(learning_rate=1e-3)
    # The original script contained debug `print('=' * n); sys.exit(0)` calls here,
    # which aborted main() before training; they are removed so the loops below run.

    for epoch in range(MAX_EPOCH):
        # Training loop
        for step, (x, y) in enumerate(train_data):
            with tf.GradientTape() as tape:
                logits = model(x)
                y_onehot = tf.one_hot(y, depth=10)
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if step % 100 == 0:
                print(epoch, step, 'loss', float(loss))

        # Evaluation on the test set
        total_num = 0
        total_correct = 0
        for x, y in test_data:
            logits = model(x)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)
            total_num += x.shape[0]
            total_correct += int(correct)
        acc = total_correct / total_num
        print(epoch, 'acc:', acc)
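# --- Illustrative example (not part of the original code) ---------------------
# The GradientTape loop above can be factored into a tf.function-decorated train
# step for graph-mode execution. A minimal sketch under the same assumptions
# (a Keras model, an Adam optimizer, 10 classes):
import tensorflow as tf

@tf.function
def train_step(model, optimizer, x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = tf.reduce_mean(
            tf.losses.categorical_crossentropy(tf.one_hot(y, depth=10), logits, from_logits=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss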
def main(args): random_seed = args.seed np.random.seed(random_seed) random.seed(random_seed) torch.manual_seed(random_seed) torch.cuda.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) torch.backends.cudnn.deterministic = True # need to set to True as well print( 'Using {}\nTest on {}\nRandom Seed {}\nk_cc {}\nk_outlier {}\nevery n epoch {}\n' .format(args.net, args.dataset, args.seed, args.k_cc, args.k_outlier, args.every)) # -- training parameters -- if args.dataset == 'cifar10' or args.dataset == 'cifar100': num_epoch = 180 milestone = [int(x) for x in args.milestone.split(',')] batch_size = 128 elif args.dataset == 'pc': num_epoch = 90 milestone = [30, 60] batch_size = 128 else: ValueError('Invalid Dataset!') start_epoch = 0 num_workers = args.nworker weight_decay = 1e-4 gamma = 0.5 lr = 0.001 which_data_set = args.dataset # 'cifar100', 'cifar10', 'pc' noise_level = args.noise # noise level noise_type = args.type # "uniform", "asymmetric" train_val_ratio = 0.8 which_net = args.net # "resnet34", "resnet18", "pc" # -- denoising related parameters -- k_cc = args.k_cc k_outlier = args.k_outlier when_to_denoise = args.start_clean # starting from which epoch we denoise denoise_every_n_epoch = args.every # every n epoch, we perform denoising # -- specify dataset -- # data augmentation if which_data_set[:5] == 'cifar': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) else: transform_train = None transform_test = None if which_data_set == 'cifar10': trainset = CIFAR10(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR10(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR10(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 10 in_channel = 3 elif which_data_set == 'cifar100': trainset = CIFAR100(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR100(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR100(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 100 in_channel = 3 elif which_data_set == 'pc': trainset = ModelNet40(split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) valset = 
ModelNet40(split='val', train_ratio=train_val_ratio, num_ptrs=1024) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = ModelNet40(split='test', num_ptrs=1024) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 40 else: ValueError('Invalid Dataset!') print('train data size:', len(trainset)) print('validation data size:', len(valset)) print('test data size:', len(testset)) ntrain = len(trainset) # -- generate noise -- y_train = trainset.get_data_labels() y_train = np.array(y_train) noise_y_train = None keep_indices = None p = None if noise_type == 'none': pass else: if noise_type == "uniform": noise_y_train, p, keep_indices = noisify_with_P( y_train, nb_classes=num_class, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) print("apply uniform noise") else: if which_data_set == 'cifar10': noise_y_train, p, keep_indices = noisify_cifar10_asymmetric( y_train, noise=noise_level, random_state=random_seed) elif which_data_set == 'cifar100': noise_y_train, p, keep_indices = noisify_cifar100_asymmetric( y_train, noise=noise_level, random_state=random_seed) elif which_data_set == 'pc': noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric( y_train, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) print("apply asymmetric noise") print("clean data num:", len(keep_indices)) print("probability transition matrix:\n{}".format(p)) # -- create log file -- file_name = '(' + which_data_set + '_' + which_net + ')' \ + 'type_' + noise_type + '_noise_' + str(noise_level) \ + '_k_cc_' + str(k_cc) + '_k_outlier_' + str(k_outlier) + '_start_' + str(when_to_denoise) \ + '_every_' + str(denoise_every_n_epoch) + '.txt' log_dir = check_folder('logs/logs_txt_' + str(random_seed)) file_name = os.path.join(log_dir, file_name) saver = open(file_name, "w") saver.write( 'noise type: {}\nnoise level: {}\nk_cc: {}\nk_outlier: {}\nwhen_to_apply_epoch: {}\n' .format(noise_type, noise_level, k_cc, k_outlier, when_to_denoise)) if noise_type != 'none': saver.write('total clean data num: {}\n'.format(len(keep_indices))) saver.write('probability transition matrix:\n{}\n'.format(p)) saver.flush() # -- set network, optimizer, scheduler, etc -- if which_net == 'resnet18': net = resnet18(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'resnet34': net = resnet34(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'pc': net = PointNetCls(k=num_class) feature_size = 256 else: ValueError('Invalid network!') optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = net.to(device) ################################################ exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestone, gamma=gamma) criterion = nn.NLLLoss() # since the output of network is by log softmax # -- misc -- best_acc = 0 best_epoch = 0 best_weights = None curr_trainloader = trainloader big_comp = set() patience = args.patience no_improve_counter = 0 # -- start training -- for epoch in range(start_epoch, num_epoch): train_correct = 0 train_loss = 0 train_total = 0 exp_lr_scheduler.step() net.train() print("current train data size:", len(curr_trainloader.dataset)) for _, (images, labels, _) in enumerate(curr_trainloader): if images.size( 0 ) == 1: # when batch size 
equals 1, skip, due to batch normalization continue images, labels = images.to(device), labels.to(device) outputs, features = net(images) log_outputs = torch.log_softmax(outputs, 1) loss = criterion(log_outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() train_acc = train_correct / train_total * 100. cprint('epoch: {}'.format(epoch), 'white') cprint( 'train accuracy: {}\ntrain loss:{}'.format(train_acc, train_loss), 'yellow') # --- compute big connected components --- net.eval() features_all = torch.zeros(ntrain, feature_size).to(device) prob_all = torch.zeros(ntrain, num_class) labels_all = torch.zeros(ntrain, num_class) train_gt_labels = np.zeros(ntrain, dtype=np.uint64) train_pred_labels = np.zeros(ntrain, dtype=np.uint64) for _, (images, labels, indices) in enumerate(trainloader): images = images.to(device) outputs, features = net(images) softmax_outputs = torch.softmax(outputs, 1) features_all[indices] = features.detach() prob_all[indices] = softmax_outputs.detach().cpu() tmp_zeros = torch.zeros(labels.shape[0], num_class) tmp_zeros[torch.arange(labels.shape[0]), labels] = 1.0 labels_all[indices] = tmp_zeros train_gt_labels[indices] = labels.cpu().numpy().astype(np.int64) train_pred_labels[indices] = labels.cpu().numpy().astype(np.int64) if epoch >= when_to_denoise and ( epoch - when_to_denoise) % denoise_every_n_epoch == 0: cprint('\n>> Computing Big Components <<', 'white') labels_all = labels_all.numpy() train_gt_labels = train_gt_labels.tolist() train_pred_labels = np.squeeze(train_pred_labels).ravel().tolist() _, idx_of_comp_idx2 = calc_topo_weights_with_components_idx( ntrain, labels_all, features_all, train_gt_labels, train_pred_labels, k=k_cc, use_log=False, cp_opt=3, nclass=num_class) # --- update largest connected component --- cur_big_comp = list(set(range(ntrain)) - set(idx_of_comp_idx2)) big_comp = big_comp.union(set(cur_big_comp)) # --- remove outliers in largest connected component --- big_com_idx = list(big_comp) feats_big_comp = features_all[big_com_idx] labels_big_comp = np.array(train_gt_labels)[big_com_idx] knnG_list = calc_knn_graph(feats_big_comp, k=args.k_outlier) knnG_list = np.array(knnG_list) knnG_shape = knnG_list.shape knn_labels = labels_big_comp[knnG_list.ravel()] knn_labels = np.reshape(knn_labels, knnG_shape) majority, counts = mode(knn_labels, axis=-1) majority = majority.ravel() counts = counts.ravel() if args.zeta > 1.0: # use majority vote non_outlier_idx = np.where(majority == labels_big_comp)[0] outlier_idx = np.where(majority != labels_big_comp)[0] outlier_idx = np.array(list(big_comp))[outlier_idx] print(">> majority == labels_big_comp -> size: ", len(non_outlier_idx)) else: # zeta filtering non_outlier_idx = np.where((majority == labels_big_comp) & ( counts >= args.k_outlier * args.zeta))[0] print(">> zeta {}, then non_outlier_idx -> size: {}".format( args.zeta, len(non_outlier_idx))) outlier_idx = np.where(majority != labels_big_comp)[0] outlier_idx = np.array(list(big_comp))[outlier_idx] cprint(">> The number of outliers: {}".format(len(outlier_idx)), 'red') cprint( ">> The purity of outliers: {}".format( np.sum(y_train[outlier_idx] == noise_y_train[outlier_idx]) / float(len(outlier_idx))), 'red') big_comp = np.array(list(big_comp))[non_outlier_idx] big_comp = set(big_comp.tolist()) # --- construct updated dataset set, which contains the collected clean data --- if which_data_set == 
'cifar10': trainset_ignore_noisy_data = CIFAR10( root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) elif which_data_set == 'cifar100': trainset_ignore_noisy_data = CIFAR100( root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) else: trainset_ignore_noisy_data = ModelNet40( split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader_ignore_noisy_data = torch.utils.data.DataLoader( trainset_ignore_noisy_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) curr_trainloader = trainloader_ignore_noisy_data noisy_data_indices = list(set(range(ntrain)) - big_comp) trainset_ignore_noisy_data.update_corrupted_label(noise_y_train) trainset_ignore_noisy_data.ignore_noise_data(noisy_data_indices) clean_data_num = len(big_comp.intersection(set(keep_indices))) noise_data_num = len(big_comp) - clean_data_num print("Big Comp Number:", len(big_comp)) print("Found Noisy Data Number:", noise_data_num) print("Found True Data Number:", clean_data_num) # compute purity of the component cc_size = len(big_comp) equal = np.sum( noise_y_train[list(big_comp)] == y_train[list(big_comp)]) ratio = equal / float(cc_size) print("Purity of current component: {}".format(ratio)) noise_size = len(noisy_data_indices) equal = np.sum(noise_y_train[noisy_data_indices] == y_train[noisy_data_indices]) print("Purity of data outside component: {}".format( equal / float(noise_size))) saver.write('Purity {}\tsize{}\t'.format(ratio, cc_size)) # --- validation --- val_total = 0 val_correct = 0 net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100. if val_acc > best_acc: best_acc = val_acc best_epoch = epoch best_weights = copy.deepcopy(net.state_dict()) no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print( '>> No improvement for {} epochs. Stop at epoch {}'.format( patience, epoch)) saver.write( '>> No improvement for {} epochs. Stop at epoch {}'.format( patience, epoch)) saver.write( '>> val epoch: {}\n>> current accuracy: {}\n'.format( epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format( best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint( '>> best accuracy: {}\n>> best epoch: {}\n'.format( best_acc, best_epoch), 'green') saver.write('{}\n'.format(val_acc)) # -- testing cprint('>> testing <<', 'cyan') test_total = 0 test_correct = 0 net.load_state_dict(best_weights) net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(testloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) test_total += images.size(0) _, predicted = outputs.max(1) test_correct += predicted.eq(labels).sum().item() test_acc = test_correct / test_total * 100. 
cprint('>> test accuracy: {}'.format(test_acc), 'cyan') saver.write('>> test accuracy: {}\n'.format(test_acc)) # retest on the validation set, for sanity check cprint('>> validation <<', 'cyan') val_total = 0 val_correct = 0 net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100. cprint('>> validation accuracy: {}'.format(val_acc), 'cyan') saver.write('>> validation accuracy: {}'.format(val_acc)) saver.close() return test_acc
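# --- Illustrative example (not part of the original code) ---------------------
# The outlier-filtering step in the training loop above keeps a sample only when
# the majority label among its k nearest neighbours (inside the retained component)
# agrees with its own label. A minimal sketch with synthetic features; a brute-force
# kNN stands in for the repo-specific calc_knn_graph.
import numpy as np
from scipy.stats import mode

rng = np.random.default_rng(0)
feats = rng.normal(size=(200, 16))        # synthetic features of the retained samples
labels = rng.integers(0, 5, size=200)     # their (possibly noisy) labels
k = 10

# Brute-force kNN graph (the repo uses calc_knn_graph instead).
d = ((feats[:, None, :] - feats[None, :, :]) ** 2).sum(-1)
np.fill_diagonal(d, np.inf)
knn_idx = np.argsort(d, axis=1)[:, :k]

# Majority label among each sample's neighbours; keep samples whose own label agrees.
knn_labels = labels[knn_idx]
majority = mode(knn_labels, axis=1).mode.ravel()
keep = np.where(majority == labels)[0]
print(len(keep), "of", len(labels), "samples kept")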
# Effective batch size scales with the number of GPUs in use.
batch_size = batch_size if len(params.gpus) == 0 else batch_size * len(params.gpus)

train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
print('train dataset len: {}'.format(len(train_dataloader.dataset)))
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)
print('val dataset len: {}'.format(len(val_dataloader.dataset)))

# Models
model = resnet18(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=120, 1 GPU memory < 7000M
model.fc = nn.Linear(512, formation_num)
# model = resnet101(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=60, 1 GPU memory > 9000M
# model.fc = nn.Linear(512 * 4, 6)

# Optimizer
trainable_vars = [param for param in model.parameters() if param.requires_grad]
print("Training with sgd")
params.optimizer = torch.optim.SGD(trainable_vars,
                                   lr=init_lr,
                                   momentum=momentum,
                                   weight_decay=weight_decay,
                                   nesterov=nesterov)
def testBlue(total_data, visualize=True):
    if visualize:
        dataset = [total_data[i] for i in np.random.choice(range(100000), 10, replace=False)]
    else:
        dataset = total_data
    images = []
    summary = [[] for i in range(formation_num)]

    model = resnet18(pretrained=False, num_classes=1000)
    model.fc = nn.Linear(512, 1)
    model.eval()

    # Load the checkpoint by matching parameters positionally against the current state dict.
    m = torch.load("./models/state_evaluation.pth")
    new_dict = model.state_dict().copy()
    for i in range(len(m)):
        new_dict[list(model.state_dict().keys())[i]] = m[list(m.keys())[i]]
    model.load_state_dict(new_dict)

    city_position = [[255] * city_grid[1] for i in range(city_grid[0])]
    for pos in CITY_POSITION:
        city_position[pos[1]][pos[0] - 1 - ocean_grid[1]] = 0

    for data in dataset:
        base_position = data[0]
        row, col = base_position.shape
        for i in range(row):
            for j in range(col):
                if base_position[i][j]:
                    assert base_position[i][j] <= MAX_MISSILE
                    base_position[i][j] = (MAX_MISSILE - base_position[i][j]) * int(255 / MAX_MISSILE)
                else:
                    base_position[i][j] = 255

        img = np.hstack((base_position, 255 - 255 * data[1], np.array(city_position)))
        h, w = img.shape
        # Pad the image to 224 x 224, keeping the content centred.
        if (224 - h) % 2 != 0 or (224 - w) % 2 != 0:
            padding = (int((224 - w) / 2), 224 - int((224 - w) / 2) - w,
                       int((224 - h) / 2), 224 - int((224 - h) / 2) - h)
        else:
            padding = (int((224 - w) / 2), int((224 - h) / 2))
        transforms1 = T.Compose([T.Pad(padding, fill=255), T.ToTensor()])
        transforms2 = T.Compose([T.Pad((0, 29), fill=255)])

        img = Image.fromarray(img.astype('uint8')).convert('L')
        eval_value = model(Variable(torch.unsqueeze(transforms1(img), 0)))
        eval_value = round(float(eval_value), 4)

        if visualize:
            image = transforms2(img)
            images.append([image, str(eval_value) + "/" + str(data[2])])
        else:
            # print(data[1], type(data[1]))
            forms = [f.tolist() for f in list(formations.values())]
            i = forms.index(data[1].tolist())
            summary[i].append(abs(eval_value - float(data[2])))

    # print(images)
    if visualize:
        plt.figure("State Evaluation")
        for i in range(1, len(images) + 1):
            plt.subplot(2, 5, i)
            plt.title(images[i - 1][1])
            plt.imshow(images[i - 1][0])
            # plt.axis('off')
        plt.show()
        if input() == " ":
            testBlue(total_data, visualize)
    else:
        avg, min_val, max_val = [], [], []
        total = 0
        for diff in summary:
            total += sum(diff)
            avg.append(round(sum(diff) / len(diff), 4))
            max_val.append(round(max(diff), 4))
            min_val.append(round(min(diff), 4))
        print(round(total / 10000, 4))

        index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        # `left=` was removed from plt.bar in matplotlib 2.x; pass the x positions positionally.
        plt.bar(index, min_val, label="Min Absolute Difference")
        plt.bar(index, avg, bottom=min_val, label="Avg Absolute Difference")
        plt.bar(index, max_val, bottom=avg, label="Max Absolute Difference")
        plt.xlabel('Formation')
        plt.ylabel('Absolute Difference Between Evaluation and Ground Truth')
        plt.legend(loc='best')
        plt.show()
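# --- Illustrative example (not part of the original code) ---------------------
# A minimal sketch of the centred padding-to-224 step used above, written as a
# standalone helper. Note that torchvision's T.Pad takes a 4-tuple in the order
# (left, top, right, bottom); the helper and the dummy image below are hypothetical.
from PIL import Image
from torchvision import transforms as T

def pad_to_square(img, size=224, fill=255):
    """Center-pad a grayscale PIL image to size x size and convert it to a tensor."""
    w, h = img.size
    left = (size - w) // 2
    top = (size - h) // 2
    pad = (left, top, size - w - left, size - h - top)   # (left, top, right, bottom)
    return T.Compose([T.Pad(pad, fill=fill), T.ToTensor()])(img)

img = Image.new('L', (90, 166), color=128)   # dummy input strip (width 90, height 166)
x = pad_to_square(img)
print(x.shape)  # torch.Size([1, 224, 224])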
def create_model(Model, num_classes):
    if Model == "ResNet18":
        model = resnet18(pretrained=False)
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()
    elif Model == "ResNet34":
        model = resnet34(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()
    elif Model == "ResNet50":
        model = resnet50(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()
    elif Model == "ResNet101":
        model = resnet101(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()
    elif Model == "MobileNet_v2":
        model = mobilenet_v2(num_classes=num_classes, pretrained=False).cuda()
    elif Model == "Mobilenetv3":
        model = mobilenetv3(n_class=num_classes, pretrained=False).cuda()
    elif Model == "shufflenet_v2_x0_5":
        model = shufflenet_v2_x0_5(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "shufflenet_v2_x1_0":
        model = shufflenet_v2_x1_0(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "shufflenet_v2_x1_5":
        model = shufflenet_v2_x1_5(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "shufflenet_v2_x2_0":
        # The original had a duplicated "shufflenet_v2_x1_5" branch here, which made
        # the x2_0 model unreachable.
        model = shufflenet_v2_x2_0(pretrained=False, num_classes=num_classes).cuda()
    elif "efficientnet" in Model:
        # from_name builds the architecture without downloading ImageNet weights,
        # matching the pretrained=False convention used by the other branches.
        model = EfficientNet.from_name(Model, num_classes=num_classes).cuda()
    elif Model == "inception_v3":
        model = inception_v3(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "mnasnet0_5":
        model = mnasnet0_5(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "vgg11":
        model = vgg11(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "vgg11_bn":
        model = vgg11_bn(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "vgg19":
        model = vgg19(pretrained=False, num_classes=num_classes).cuda()
    elif Model == "densenet121":
        model = densenet121(pretrained=False).cuda()
    elif Model == "ResNeXt29_32x4d":
        model = ResNeXt29_32x4d(num_classes=num_classes).cuda()
    elif Model == "ResNeXt29_2x64d":
        model = ResNeXt29_2x64d(num_classes=num_classes).cuda()
    else:
        raise ValueError("model error: unknown model name {}".format(Model))

    # input = torch.randn(1, 3, 32, 32).cuda()
    # flops, params = profile(model, inputs=(input,))
    # flops, params = clever_format([flops, params], "%.3f")
    # print("------------------------------------------------------------------------")
    # print(" ", Model)
    # print(" flops:", flops, " params:", params)
    # print("------------------------------------------------------------------------")
    return model
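# --- Illustrative example (not part of the original code) ---------------------
# The long elif chain above can be collapsed into a name-to-constructor table.
# A minimal sketch of that pattern; MODEL_FACTORIES and create_model_from_table
# are hypothetical names, and torchvision >= 0.13 is assumed for `weights=None`.
import torch
import torch.nn as nn
from torchvision.models import resnet18, resnet34, resnet50

MODEL_FACTORIES = {
    "ResNet18": lambda: resnet18(weights=None),
    "ResNet34": lambda: resnet34(weights=None),
    "ResNet50": lambda: resnet50(weights=None),
}

def create_model_from_table(name, num_classes):
    try:
        model = MODEL_FACTORIES[name]()
    except KeyError:
        raise ValueError("unknown model name: {}".format(name))
    # Replace the classification head, as the branches above do.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.cuda() if torch.cuda.is_available() else model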
def main(args): random_seed = int(np.random.choice(range(1000), 1)) np.random.seed(random_seed) random.seed(random_seed) torch.manual_seed(random_seed) torch.cuda.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) torch.backends.cudnn.deterministic = True arg_which_net = args.network arg_dataset = args.dataset arg_epoch_start = args.epoch_start lr = args.lr arg_gpu = args.gpu arg_num_gpu = args.n_gpus arg_every_n_epoch = args.every_n_epoch # interval to perform the correction arg_epoch_update = args.epoch_update # the epoch to start correction (warm-up period) arg_epoch_interval = args.epoch_interval # interval between two update of A noise_level = args.noise_level noise_type = args.noise_type # "uniform", "asymmetric", "none" train_val_ratio = 0.9 which_net = arg_which_net # "cnn" "resnet18" "resnet34" "preact_resnet18" "preact_resnet34" "preact_resnet101" "pc" num_epoch = args.n_epochs # Total training epochs print('Using {}\nTest on {}\nRandom Seed {}\nevery n epoch {}\nStart at epoch {}'. format(arg_which_net, arg_dataset, random_seed, arg_every_n_epoch, arg_epoch_start)) # -- training parameters if arg_dataset == 'mnist': milestone = [30, 60] batch_size = 64 in_channels = 1 elif arg_dataset == 'cifar10': milestone = [60, 180] batch_size = 128 in_channels = 1 elif arg_dataset == 'cifar100': milestone = [60, 180] batch_size = 128 in_channels = 1 elif arg_dataset == 'pc': milestone = [30, 60] batch_size = 128 start_epoch = 0 num_workers = 1 #gamma = 0.5 # -- specify dataset # data augmentation if arg_dataset == 'cifar10': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) elif arg_dataset == 'cifar100': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.507, 0.487, 0.441), (0.507, 0.487, 0.441)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.507, 0.487, 0.441), (0.507, 0.487, 0.441)), ]) elif arg_dataset == 'mnist': transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) else: transform_train = None transform_test = None if arg_dataset == 'cifar10': trainset = CIFAR10(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR10(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR10(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) images_size = [3, 32, 32] num_class = 10 in_channel = 3 elif arg_dataset == 'cifar100': trainset = CIFAR100(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn) valset = CIFAR100(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR100(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 100 in_channel = 3 elif arg_dataset == 'mnist': trainset = MNIST(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn) valset = MNIST(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = MNIST(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 10 in_channel = 1 elif arg_dataset == 'pc': trainset = ModelNet40(split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) valset = ModelNet40(split='val', train_ratio=train_val_ratio, num_ptrs=1024) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = ModelNet40(split='test', num_ptrs=1024) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 40 print('train data size:', len(trainset)) print('validation data size:', len(valset)) print('test data size:', len(testset)) eps = 1e-6 # This is the epsilon used to soft the label (not the epsilon in the paper) ntrain = len(trainset) # -- generate noise -- # y_train is ground truth labels, we should not have any access to this after the noisy labels are generated # algorithm after y_tilde is generated has nothing to do with y_train y_train = trainset.get_data_labels() y_train = np.array(y_train) noise_y_train = None keep_indices = None p = None if(noise_type == 'none'): pass else: if noise_type == "uniform": noise_y_train, p, keep_indices = noisify_with_P(y_train, nb_classes=num_class, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) noise_softlabel = torch.ones(ntrain, num_class)*eps/(num_class-1) noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1-eps) trainset.update_corrupted_softlabel(noise_softlabel) print("apply uniform noise") else: if arg_dataset == 'cifar10': noise_y_train, p, keep_indices = noisify_cifar10_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'cifar100': noise_y_train, p, keep_indices = noisify_cifar100_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'mnist': noise_y_train, p, keep_indices = noisify_mnist_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'pc': noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric(y_train, 
noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) noise_softlabel = torch.ones(ntrain, num_class) * eps / (num_class - 1) noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1 - eps) trainset.update_corrupted_softlabel(noise_softlabel) print("apply asymmetric noise") print("clean data num:", len(keep_indices)) print("probability transition matrix:\n{}".format(p)) # -- create log file file_name = '[' + arg_dataset + '_' + which_net + ']' \ + 'type:' + noise_type + '_' + 'noise:' + str(noise_level) + '_' \ + '_' + 'start:' + str(arg_epoch_start) + '_' \ + 'every:' + str(arg_every_n_epoch) + '_'\ + 'time:' + str(datetime.datetime.now()) + '.txt' log_dir = check_folder('new_logs/logs_txt_' + str(random_seed)) file_name = os.path.join(log_dir, file_name) saver = open(file_name, "w") saver.write('noise type: {}\nnoise level: {}\nwhen_to_apply_epoch: {}\n'.format( noise_type, noise_level, arg_epoch_start)) if noise_type != 'none': saver.write('total clean data num: {}\n'.format(len(keep_indices))) saver.write('probability transition matrix:\n{}\n'.format(p)) saver.flush() # -- set network, optimizer, scheduler, etc if which_net == "cnn": net_trust = CNN9LAYER(input_channel=in_channel, n_outputs=num_class) net = CNN9LAYER(input_channel=in_channel, n_outputs=num_class) net.apply(weight_init) feature_size = 128 elif which_net == 'resnet18': net_trust = resnet18(in_channel=in_channel, num_classes=num_class) net = resnet18(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'resnet34': net_trust = resnet34(in_channel=in_channel, num_classes=num_class) net = resnet34(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'preact_resnet18': net_trust = preact_resnet18(num_classes=num_class, num_input_channels=in_channel) net = preact_resnet18(num_classes=num_class, num_input_channels=in_channel) feature_size = 256 elif which_net == 'preact_resnet34': net_trust = preact_resnet34(num_classes=num_class, num_input_channels=in_channel) net = preact_resnet34(num_classes=num_class, num_input_channels=in_channel) feature_size = 256 elif which_net == 'preact_resnet101': net_trust = preact_resnet101() net = preact_resnet101() feature_size = 256 elif which_net == 'pc': net_trust = PointNetCls(k=num_class) net = PointNetCls(k=num_class) feature_size = 256 else: ValueError('Invalid network!') opt_gpus = [i for i in range(arg_gpu, arg_gpu+int(arg_num_gpu))] if len(opt_gpus) > 1: print("Using ", len(opt_gpus), " GPUs") os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in opt_gpus) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device) if len(opt_gpus) > 1: net_trust = torch.nn.DataParallel(net_trust) net = torch.nn.DataParallel(net) net_trust.to(device) net.to(device) # net.apply(conv_init) # ------------------------- Initialize Setting ------------------------------ best_acc = 0 best_epoch = 0 patience = 50 no_improve_counter = 0 A = 1/num_class*torch.ones(ntrain, num_class, num_class, requires_grad=False).float().to(device) h = np.zeros([ntrain, num_class]) criterion_1 = nn.NLLLoss() pred_softlabels = np.zeros([ntrain, arg_every_n_epoch, num_class], dtype=np.float) train_acc_record = [] clean_train_acc_record = [] noise_train_acc_record = [] val_acc_record = [] recovery_record = [] noise_ytrain = copy.copy(noise_y_train) #noise_ytrain = torch.tensor(noise_ytrain).to(device) cprint("================ Clean Label... 
================", "yellow") for epoch in range(num_epoch): # Add some modification here train_correct = 0 train_loss = 0 train_total = 0 delta = 1.2 + 0.02*max(epoch - arg_epoch_update + 1, 0) clean_train_correct = 0 noise_train_correct = 0 optimizer_trust = RAdam(net_trust.parameters(), lr=learning_rate(lr, epoch), weight_decay=5e-4) #optimizer_trust = optim.SGD(net_trust.parameters(), lr=learning_rate(lr, epoch), weight_decay=5e-4, # nesterov=True, momentum=0.9) net_trust.train() # Train with noisy data for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)): if images.size(0) == 1: # when batch size equals 1, skip, due to batch normalization continue images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device) outputs, features = net_trust(images) log_outputs = torch.log_softmax(outputs, 1).float() # arg_epoch_start : epoch start to introduce loss retro # arg_epoch_interval : epochs between two updating of A if epoch in [arg_epoch_start-1, arg_epoch_start+arg_epoch_interval-1]: h[indices] = log_outputs.detach().cpu() normal_outputs = torch.softmax(outputs, 1) if epoch >= arg_epoch_start: # use loss_retro + loss_ce A_batch = A[indices].to(device) loss = sum([-A_batch[i].matmul(softlabels[i].reshape(-1, 1).float()).t().matmul(log_outputs[i]) for i in range(len(indices))]) / len(indices) + \ criterion_1(log_outputs, labels) else: # use loss_ce loss = criterion_1(log_outputs, labels) optimizer_trust.zero_grad() loss.backward() optimizer_trust.step() #arg_every_n_epoch : rolling windows to get eta_tilde if epoch >= (arg_epoch_update - arg_every_n_epoch): pred_softlabels[indices, epoch % arg_every_n_epoch, :] = normal_outputs.detach().cpu().numpy() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() # For monitoring purpose, comment the following line out if you have your own dataset that doesn't have ground truth label train_label_clean = torch.tensor(y_train)[indices].to(device) train_label_noise = torch.tensor(noise_ytrain[indices]).to(device) clean_train_correct += predicted.eq(train_label_clean).sum().item() noise_train_correct += predicted.eq(train_label_noise).sum().item() # acc wrt the original noisy labels train_acc = train_correct / train_total * 100 clean_train_acc = clean_train_correct/train_total*100 noise_train_acc = noise_train_correct/train_total*100 print(" Train Epoch: [{}/{}] \t Training Acc wrt Corrected {:.3f} \t Train Acc wrt True {:.3f} \t Train Acc wrt Noise {:.3f}". 
format(epoch, num_epoch, train_acc, clean_train_acc, noise_train_acc)) # updating A if epoch in [arg_epoch_start - 1, arg_epoch_start + 40 - 1]: cprint("+++++++++++++++++ Updating A +++++++++++++++++++", "magenta") unsolved = 0 infeasible = 0 y_soft = trainset.get_data_softlabel() with torch.no_grad(): for i in tqdm(range(ntrain), ncols=100, ascii=True): try: result, A_opt = updateA(y_soft[i], h[i], rho=0.9) except: A[i] = A[i] unsolved += 1 continue if (result == np.inf): A[i] = A[i] infeasible += 1 else: A[i] = torch.tensor(A_opt) print(A[0]) print("Unsolved points: {} | Infeasible points: {}".format(unsolved, infeasible)) # applying lRT scheme # args_epoch_update : epoch to update labels if epoch >= arg_epoch_update: y_tilde = trainset.get_data_labels() pred_softlabels_bar = pred_softlabels.mean(1) clean_labels, clean_softlabels = lrt_flip_scheme(pred_softlabels_bar, y_tilde, delta) trainset.update_corrupted_softlabel(clean_softlabels) trainset.update_corrupted_label(clean_softlabels.argmax(1)) # validation if not (epoch % 5): val_total = 0 val_correct = 0 net_trust.eval() with torch.no_grad(): for i, (images, labels, _, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net_trust(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100 train_acc_record.append(train_acc) val_acc_record.append(val_acc) clean_train_acc_record.append(clean_train_acc) noise_train_acc_record.append(noise_train_acc) recovery_acc = np.sum(trainset.get_data_labels() == y_train) / ntrain recovery_record.append(recovery_acc) if val_acc > best_acc: best_acc = val_acc best_epoch = epoch no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch)) saver.write('>> No improvement for {} epochs. 
Stop at epoch {}'.format(patience, epoch)) saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green') cprint('>> final recovery rate: {}\n'.format(recovery_acc), 'green') saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write("outputs: {}\n".format(normal_outputs)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) saver.write( '>> final recovery rate: {}%\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain * 100)) saver.flush() # If want to train the neural network again with corrected labels and original loss_ce again # set args.two_stage to True print("Use Two-Stage Model {}".format(args.two_stage)) if args.two_stage==True: criterion_2 = nn.NLLLoss() best_acc = 0 best_epoch = 0 patience = 50 no_improve_counter = 0 cprint("================ Normal Training ================", "yellow") for epoch in range(num_epoch): # Add some modification here train_correct = 0 train_loss = 0 train_total = 0 optimizer_trust = optim.SGD(net.parameters(), momentum=0.9, nesterov=True, lr=learning_rate(lr, epoch), weight_decay=5e-4) net.train() for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)): if images.size(0) == 1: # when batch size equals 1, skip, due to batch normalization continue images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device) outputs, features = net(images) log_outputs = torch.log_softmax(outputs, 1).float() loss = criterion_2(log_outputs, labels) optimizer_trust.zero_grad() loss.backward() optimizer_trust.step() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() train_acc = train_correct / train_total * 100 print(" Train Epoch: [{}/{}] \t Training Accuracy {}%".format(epoch, num_epoch, train_acc)) if not (epoch % 5): val_total = 0 val_correct = 0 net_trust.eval() with torch.no_grad(): for i, (images, labels, _, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100 if val_acc > best_acc: best_acc = val_acc best_epoch = epoch no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch)) saver.write('>> No improvement for {} epochs. 
Stop at epoch {}'.format(patience, epoch)) saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green') cprint('>> final recovery rate: {}\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain), 'green') cprint("================ Start Testing ================", "yellow") test_total = 0 test_correct = 0 net_trust.eval() for i, (images, labels, softlabels, indices) in enumerate(testloader): if images.shape[0] == 1: continue images, labels = images.to(device), labels.to(device) outputs, _ = net_trust(images) test_total += images.shape[0] test_correct += outputs.argmax(1).eq(labels).sum().item() test_acc = test_correct/test_total*100 print("Final test accuracy {} %".format(test_correct/test_total*100)) return test_acc
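# --- Illustrative example (not part of the original code) ---------------------
# updateA and lrt_flip_scheme are repo functions whose internals are not shown in
# this excerpt. Purely as an illustration of the rolling-window relabeling idea,
# the sketch below averages the softmax predictions collected in pred_softlabels
# over the window and flips a label when the top class beats the current label's
# averaged probability by the margin delta. This approximates, but is not
# identical to, the repo's lrt_flip_scheme.
import numpy as np

def rolling_relabel(pred_softlabels, y_tilde, delta=1.2):
    """Average softmax over the rolling window; flip a label when the likelihood
    ratio between the top class and the current label exceeds delta."""
    avg = pred_softlabels.mean(axis=1)          # (n, num_class)
    top = avg.argmax(axis=1)
    n = len(y_tilde)
    ratio = avg[np.arange(n), top] / np.maximum(avg[np.arange(n), y_tilde], 1e-12)
    return np.where(ratio > delta, top, y_tilde)

# Toy usage: 5 samples, a window of 3 epochs, 4 classes.
preds = np.random.default_rng(1).dirichlet(np.ones(4), size=(5, 3))
labels = np.array([0, 1, 2, 3, 0])
print(rolling_relabel(preds, labels))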
def __init__(self, ver_dim, seg_dim, fcdim=256, s8dim=128, s4dim=64, s2dim=32, raw_dim=32, inp_dim=3):
    super(Resnet18_8s, self).__init__()
    # Load the pretrained weights, remove the avg-pool layer and get an output stride of 8.
    resnet18_8s = resnet18(
        inp_dim=inp_dim,
        fully_conv=True,
        pretrained=True,
        output_stride=8,
        remove_avg_pool_layer=True,
    )
    self.ver_dim = ver_dim
    self.seg_dim = seg_dim

    # Randomly initialized conv scoring head replacing the fully connected layer.
    resnet18_8s.fc = nn.Sequential(
        nn.Conv2d(resnet18_8s.inplanes, fcdim, 3, 1, 1, bias=False),
        nn.BatchNorm2d(fcdim),
        nn.ReLU(True),
    )
    self.resnet18_8s = resnet18_8s

    # x8s -> 128
    self.conv8s = nn.Sequential(
        nn.Conv2d(128 + fcdim, s8dim, 3, 1, 1, bias=False),
        nn.BatchNorm2d(s8dim),
        nn.LeakyReLU(0.1, True),
    )
    self.up8sto4s = nn.UpsamplingBilinear2d(scale_factor=2)
    # x4s -> 64
    self.conv4s = nn.Sequential(
        nn.Conv2d(64 + s8dim, s4dim, 3, 1, 1, bias=False),
        nn.BatchNorm2d(s4dim),
        nn.LeakyReLU(0.1, True),
    )
    self.up4sto2s = nn.UpsamplingBilinear2d(scale_factor=2)
    # x2s -> 64
    self.conv2s = nn.Sequential(
        nn.Conv2d(64 + s4dim, s2dim, 3, 1, 1, bias=False),
        nn.BatchNorm2d(s2dim),
        nn.LeakyReLU(0.1, True),
    )
    self.up2storaw = nn.UpsamplingBilinear2d(scale_factor=2)

    self.convraw = nn.Sequential(
        nn.Conv2d(inp_dim + s2dim, raw_dim, 3, 1, 1, bias=False),
        nn.BatchNorm2d(raw_dim),
        nn.LeakyReLU(0.1, True),
        nn.Conv2d(raw_dim, self.seg_dim + self.ver_dim, 1, 1),
    )
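# --- Illustrative example (not part of the original code) ---------------------
# A self-contained sketch of the skip-connection pattern the decoder above relies
# on: concatenate a shallow stride-8 feature map with the deep output, fuse with a
# conv block, then upsample towards stride 4. The dummy tensors and spatial sizes
# are assumptions; channel sizes follow the defaults above.
import torch
import torch.nn as nn

fcdim, s8dim = 256, 128
conv8s = nn.Sequential(
    nn.Conv2d(128 + fcdim, s8dim, 3, 1, 1, bias=False),
    nn.BatchNorm2d(s8dim),
    nn.LeakyReLU(0.1, True),
)
up8sto4s = nn.UpsamplingBilinear2d(scale_factor=2)

x8s = torch.randn(1, 128, 60, 80)    # stride-8 skip feature from the trunk
xfc = torch.randn(1, fcdim, 60, 80)  # output of the conv scoring head (same resolution)
fm = conv8s(torch.cat([xfc, x8s], 1))   # fuse deep and skip features
fm = up8sto4s(fm)                       # upsample towards stride 4
print(fm.shape)  # torch.Size([1, 128, 120, 160])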
def __init__(self, mode):
    super(buildlModel, self).__init__()
    self.backbone = resnet18(pretrained=True)
    self.neck = AdjustLayer(256, 256)
    self.rpn = RPN(256, 256, 1)
    self.mode = mode
def __init__(self):
    self.img_size = constant.IMG_SIZE
    self.noise_dim = config.generator["noise_dim"]
    self.generator = nn.DataParallel(
        TP_GAN.Generator(
            noise_dim=config.generator["noise_dim"],
            encode_feature_dim=config.generator["encode_feature_dim"],
            encode_predict_dim=config.generator["encode_predict_dim"],
            use_batchnorm=config.generator["use_batchnorm"],
            use_residual_block=config.generator["use_residual_block"]))
    self.discriminator = nn.DataParallel(
        TP_GAN.Discriminator(config.discriminator["use_batchnorm"]))
    self.lr = config.settings["init_lr"]
    self.optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=self.lr)
    self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
    self.lr_scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
        self.optimizer_G, milestones=[], gamma=0.1)
    self.lr_scheduler_D = torch.optim.lr_scheduler.MultiStepLR(
        self.optimizer_D, milestones=[], gamma=0.1)
    self.epoch = config.settings["epoch"]
    self.batch_size = config.settings["batch_size"]
    self.train_dataloader = torch.utils.data.DataLoader(
        dataset.train_dataloader(config.path["img_test"],
                                 batch_size=self.batch_size,
                                 shuffle=True,
                                 num_workers=8,
                                 pin_memory=True))
    self.feature_extract_network_resnet18 = nn.DataParallel(
        resnet.resnet18(
            pretrained=True,
            num_classes=config.generator["encode_predict_dim"]))
    self.l1_loss = nn.L1Loss()
    self.MSE_loss = nn.MSELoss()
    self.cross_entropy = nn.CrossEntropyLoss()
    if USE_CUDA:
        self.generator = self.generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.l1_loss = self.l1_loss.cuda()
        self.MSE_loss = self.MSE_loss.cuda()
        self.cross_entropy = self.cross_entropy.cuda()

def backward_D(self, img_predict, img_frontal):
    set_requires_grad(self.discriminator, True)
    # WGAN critic loss: -E[D(real)] + E[D(fake)]
    adversarial_D_loss = -torch.mean(self.discriminator(img_frontal)) + \
        torch.mean(self.discriminator(img_predict))

    # WGAN-GP gradient penalty on points interpolated between real and fake images.
    factor = torch.rand(img_frontal.shape[0], 1, 1, 1).expand(img_frontal.size())
    if USE_CUDA:
        factor = factor.cuda()
    interpolated_value = Variable(
        factor * img_predict.data + (1.0 - factor) * img_frontal.data,
        requires_grad=True)
    output = self.discriminator(interpolated_value)
    gradient = torch.autograd.grad(outputs=output,
                                   inputs=interpolated_value,
                                   grad_outputs=torch.ones_like(output),
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]
    gradient = gradient.view(output.shape[0], -1)
    # Penalize each sample's gradient norm for deviating from 1.
    gp_loss = torch.mean((torch.norm(gradient, p=2, dim=1) - 1) ** 2)

    loss_D = adversarial_D_loss + config.loss["weight_gradient_penalty"] * gp_loss
    self.optimizer_D.zero_grad()
    loss_D.backward()
    self.optimizer_D.step()
    return loss_D

def backward_G(self, inputs, outputs):
    set_requires_grad(self.discriminator, False)

    # adversarial loss
    adversarial_G_loss = -torch.mean(self.discriminator(outputs["img128_predict"]))

    # pixel-wise loss
    pixelwise_loss_128 = self.l1_loss(inputs["img128_frontal"], outputs["img128_predict"])
    pixelwise_loss_64 = self.l1_loss(inputs["img64_frontal"], outputs["img64_predict"])
    pixelwise_loss_32 = self.l1_loss(inputs["img32_frontal"], outputs["img32_predict"])
    pixelwise_global_loss = config.loss["pixelwise_weight_128"] * pixelwise_loss_128 + \
        config.loss["pixelwise_weight_64"] * pixelwise_loss_64 + \
        config.loss["pixelwise_weight_32"] * pixelwise_loss_32

    left_eye_loss = self.l1_loss(inputs["left_eye_frontal"], outputs["left_eye_predict"])
    right_eye_loss = self.l1_loss(inputs["right_eye_frontal"], outputs["right_eye_predict"])
    nose_loss = self.l1_loss(inputs["nose_frontal"], outputs["nose_predict"])
    mouth_loss = self.l1_loss(inputs["mouth_frontal"], outputs["mouth_predict"])
    pixel_local_loss = left_eye_loss + right_eye_loss + nose_loss + mouth_loss

    # symmetry loss (compare each prediction against its horizontal flip)
    img128 = outputs["img128_predict"]
    img64 = outputs["img64_predict"]
    img32 = outputs["img32_predict"]
    inv_idx128 = torch.arange(img128.size()[3] - 1, -1, -1).long()
    inv_idx64 = torch.arange(img64.size()[3] - 1, -1, -1).long()
    inv_idx32 = torch.arange(img32.size()[3] - 1, -1, -1).long()
    if USE_CUDA:
        inv_idx128 = inv_idx128.cuda()
        inv_idx64 = inv_idx64.cuda()
        inv_idx32 = inv_idx32.cuda()
    img128_flip = img128.index_select(3, Variable(inv_idx128))
    img64_flip = img64.index_select(3, Variable(inv_idx64))
    img32_flip = img32.index_select(3, Variable(inv_idx32))
    img128_flip.detach_()
    img64_flip.detach_()
    img32_flip.detach_()
    symmetry_loss_128 = self.l1_loss(img128, img128_flip)
    symmetry_loss_64 = self.l1_loss(img64, img64_flip)
    symmetry_loss_32 = self.l1_loss(img32, img32_flip)
    symmetry_loss = config.loss["symmetry_weight_128"] * symmetry_loss_128 + \
        config.loss["symmetry_weight_64"] * symmetry_loss_64 + \
        config.loss["symmetry_weight_32"] * symmetry_loss_32

    # identity preserving loss
    feature_frontal = self.feature_extract_network_resnet18(inputs["img128_frontal"])
    feature_predict = self.feature_extract_network_resnet18(outputs["img128_predict"])
    identity_preserving_loss = self.MSE_loss(feature_frontal, feature_predict)

    # total variation loss for regularization
    img128 = outputs["img128_predict"]
    total_variation_loss = torch.mean(torch.abs(img128[:, :, :-1, :] - img128[:, :, 1:, :])) + \
        torch.mean(torch.abs(img128[:, :, :, :-1] - img128[:, :, :, 1:]))

    # cross entropy loss on the identity prediction
    cross_entropy_loss = self.cross_entropy(outputs["encode_predict"], inputs["label"])

    # synthesized loss
    synthesized_loss = config.loss["pixelwise_global_weight"] * pixelwise_global_loss + \
        config.loss["pixel_local_weight"] * pixel_local_loss + \
        config.loss["symmetry_weight"] * symmetry_loss + \
        config.loss["adversarial_G_weight"] * adversarial_G_loss + \
        config.loss["identity_preserving_weight"] * identity_preserving_loss + \
        config.loss["total_variation_weight"] * total_variation_loss
    loss_G = synthesized_loss + config.loss["cross_entropy_weight"] * cross_entropy_loss

    self.optimizer_G.zero_grad()
    loss_G.backward()
    self.optimizer_G.step()

def train(self):
    self.generator.train()
    self.discriminator.train()
    self.feature_extract_network_resnet18.eval()
    for cur_epoch in range(self.epoch):
        print("\ncurrent epoch number: %d" % cur_epoch)
        self.lr_scheduler_G.step()
        self.lr = self.lr_scheduler_G.get_lr()[0]
        self.lr_scheduler_D.step()
        self.lr = self.lr_scheduler_D.get_lr()[0]
        for batch_index, inputs in enumerate(self.train_dataloader):
            noise = Variable(
                torch.FloatTensor(
                    np.random.uniform(-1, 1, (self.batch_size, self.noise_dim))))
            # Move the batch to the GPU and wrap it; write the result back into the dict.
            for k, v in inputs.items():
                if USE_CUDA:
                    v = v.cuda()
                inputs[k] = Variable(v, requires_grad=False)
            if USE_CUDA:
                noise = noise.cuda()
            generator_output = self.generator(inputs["img"], inputs["left_eye"],
                                              inputs["right_eye"], inputs["nose"],
                                              inputs["mouth"], noise)
            # backward passes for the discriminator and the generator
            self.backward_D(generator_output["img128_predict"].detach(),
                            inputs["img128_frontal"])
            self.backward_G(generator_output, inputs)
        self.save_model()

def save_model(self):
    torch.save(self.generator.cpu().state_dict(), config.path["generator__path"])
    torch.save(self.discriminator.cpu().state_dict(), config.path["discriminator_save_path"])
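# --- Illustrative example (not part of the original code) ---------------------
# For reference, a self-contained sketch of the WGAN-GP gradient penalty used in
# backward_D, written against a toy critic. All names here are illustrative and
# not from the original repo; nn.LazyLinear assumes a reasonably recent PyTorch.
import torch
import torch.nn as nn

def gradient_penalty(critic, real, fake):
    """E[(||grad D(x_hat)||_2 - 1)^2] with x_hat sampled on lines between real and fake."""
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1.0 - eps) * fake).requires_grad_(True)
    out = critic(x_hat)
    grad = torch.autograd.grad(outputs=out, inputs=x_hat,
                               grad_outputs=torch.ones_like(out),
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1.0) ** 2).mean()

critic = nn.Sequential(nn.Conv2d(3, 8, 3, 2, 1), nn.Flatten(), nn.LazyLinear(1))
real, fake = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
print(gradient_penalty(critic, real, fake).item())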