def __init__(self, trunk_name, output_stride=8, pretrained=True):
    super(get_resnet, self).__init__()

    # Select the trunk. For the torchvision-style ResNets the stem
    # (conv1/bn1/relu/maxpool) is folded into a single `layer0` block below.
    if trunk_name == "resnet18":
        resnet = resnet18(pretrained=pretrained)
    elif trunk_name == "resnet34":
        resnet = resnet34(pretrained=pretrained)
    elif trunk_name == "resnet50":
        resnet = resnet50(pretrained=pretrained)
    elif trunk_name == "resnet101":
        resnet = resnet101(pretrained=pretrained)
    elif trunk_name == "resnet152":
        resnet = resnet152(pretrained=pretrained)
    elif trunk_name == "resnext101_32x8d":
        resnet = resnext101_32x8d(pretrained=pretrained)
    elif trunk_name == "se_resnet50":
        # The "se_resnet" names are backed by SE-ResNeXt trunks, which already
        # expose their own `layer0`.
        resnet = se_resnext50_32x4d(pretrained=pretrained)
    elif trunk_name == "se_resnet101":
        resnet = se_resnext101_32x4d(pretrained=pretrained)
    else:
        raise KeyError("[*] Not a valid network arch")

    if not hasattr(resnet, "layer0"):
        resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                      resnet.relu, resnet.maxpool)

    self.layer0 = resnet.layer0
    self.layer1, self.layer2, self.layer3, self.layer4 = \
        resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

    # Replace the strided 3x3 convs in the later stages with dilated convs so
    # the overall output stride becomes 8 or 16 instead of the default 32.
    if output_stride == 8:
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    elif output_stride == 16:
        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)
    else:
        raise KeyError("[*] Unsupported output_stride {}".format(output_stride))
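# --- Hedged usage sketch (not part of the original module) ---
# The dilation rewiring above trades stride for dilation in layer3/layer4 so the
# trunk keeps a finer feature map. This sketch replicates the same rewiring on a
# stock torchvision resnet50 (an assumption; the project may wrap its own trunk,
# and `weights=None` assumes the torchvision >= 0.13 API) and checks that a
# 224x224 input now yields a stride-8 feature map.
import torch
from torchvision.models import resnet50 as tv_resnet50

net = tv_resnet50(weights=None)
for n, m in net.layer3.named_modules():
    if 'conv2' in n:
        m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
    elif 'downsample.0' in n:
        m.stride = (1, 1)
for n, m in net.layer4.named_modules():
    if 'conv2' in n:
        m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
    elif 'downsample.0' in n:
        m.stride = (1, 1)

x = torch.randn(1, 3, 224, 224)
h = net.maxpool(net.relu(net.bn1(net.conv1(x))))
h = net.layer4(net.layer3(net.layer2(net.layer1(h))))
print(h.shape)  # torch.Size([1, 2048, 28, 28]) -> output stride 8 instead of 32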
def __init__(self, num_classes=256, input_bits=8, embed_dim=8, seq_len=100, backbone='resnet50'):
    super(CNN, self).__init__()
    self.input_bits = input_bits
    self.seq_len = seq_len

    # Map each possible input value (2 ** input_bits of them) to an embed_dim vector.
    # self.embed = nn.Embedding(256, 256)
    self.embed = nn.Embedding(2 ** input_bits, embed_dim)

    # Build only the requested backbone; unknown names fall back to resnet50.
    backbones = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
    }
    self.backbone = backbones.get(backbone, resnet50)(n_classes=num_classes,
                                                      input_channels=embed_dim)
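# --- Hedged shape sketch (illustrative only; shapes and the permute are assumptions,
# since the forward pass is not shown here) ---
# A batch of byte sequences is embedded to (batch, seq_len, embed_dim) and then
# rearranged channels-first, which is the layout a 1-D convolutional backbone
# would typically consume.
import torch
import torch.nn as nn

embed = nn.Embedding(256, 8)                 # 2 ** input_bits values -> embed_dim features
x = torch.randint(0, 256, (4, 100))          # (batch, seq_len) of raw byte values
h = embed(x)                                 # (4, 100, 8)
h = h.permute(0, 2, 1).contiguous()          # (4, 8, 100): channels-first for a 1-D backbone
print(h.shape)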
def __init__(self, args, train_transform=None, test_transform=None, val_transform=None):
    self.args = args
    self.train_transform = train_transform
    self.test_transform = test_transform
    self.val_transform = val_transform

    # Per class-pair loss weights, initialised to a uniform 1/num_classes prior.
    self.loss_lams = torch.zeros(args.num_classes, args.num_classes,
                                 dtype=torch.float32).cuda()
    self.loss_lams[:, :] = 1. / args.num_classes
    self.loss_lams.requires_grad = False

    if self.args.arch == 'resnet50':
        self.net = resnet.resnet50(num_classes=args.num_classes)
    elif self.args.arch == 'res2net50':
        self.net = res2net.res2net50_26w_4s(num_classes=args.num_classes)
    elif self.args.arch == 'resnet101':
        self.net = resnet.resnet101()
    elif self.args.arch == 'resnet18':
        self.net = resnet.resnet18()
    elif self.args.arch == 'resnet34':
        self.net = resnet.resnet34()
    elif self.args.arch == 'vgg16':
        self.net = vgg_cub.vgg16()
    elif self.args.arch == 'vgg16_bn':
        self.net = vgg_cub.vgg16_bn()
    elif self.args.arch == 'vgg19':
        self.net = vgg_cub.vgg19()
    elif self.args.arch == 'vgg19_bn':
        self.net = vgg_cub.vgg19_bn()
    elif self.args.arch == 'vgg16_std':
        self.net = vgg_std.vgg16()
    elif self.args.arch == 'vgg16_bn_std':
        self.net = vgg_std.vgg16_bn()
    elif self.args.arch == 'mobilenetv2':
        self.net = mobilenetv2.mobilenet_v2(num_classes=args.num_classes)
    else:
        raise ValueError('Unsupported arch: {}'.format(self.args.arch))

    if self.args.load_model is not None:
        self.net.load_state_dict(torch.load(self.args.load_model), strict=True)
        print('load model from %s' % self.args.load_model)
    elif self.args.pretrained_model is not None:
        self.net.load_state_dict(torch.load(self.args.pretrained_model), strict=False)
        print('load pretrained model from %s' % self.args.pretrained_model)
    else:
        print('no model loaded, will train from scratch!')

    if args.expname is None:
        args.expname = 'runs/{}_{}_{}'.format(args.arch, args.dataset, args.method)
    os.makedirs(args.expname, exist_ok=True)
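# --- Hedged illustration (not from the original code) of strict vs non-strict loading ---
# The branch above loads full checkpoints with strict=True and backbone-only
# checkpoints with strict=False; with strict=False, load_state_dict skips
# mismatched keys and returns the keys it skipped, which is worth logging.
import torch
import torch.nn as nn

src = nn.Sequential(nn.Linear(4, 8))                       # e.g. a backbone-only checkpoint
dst = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))      # same backbone plus a new head
incompat = dst.load_state_dict(src.state_dict(), strict=False)
print(incompat.missing_keys)    # ['1.weight', '1.bias']: the head keeps its fresh init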
def main(args): random_seed = args.seed np.random.seed(random_seed) random.seed(random_seed) torch.manual_seed(random_seed) torch.cuda.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) torch.backends.cudnn.deterministic = True # need to set to True as well print( 'Using {}\nTest on {}\nRandom Seed {}\nk_cc {}\nk_outlier {}\nevery n epoch {}\n' .format(args.net, args.dataset, args.seed, args.k_cc, args.k_outlier, args.every)) # -- training parameters -- if args.dataset == 'cifar10' or args.dataset == 'cifar100': num_epoch = 180 milestone = [int(x) for x in args.milestone.split(',')] batch_size = 128 elif args.dataset == 'pc': num_epoch = 90 milestone = [30, 60] batch_size = 128 else: ValueError('Invalid Dataset!') start_epoch = 0 num_workers = args.nworker weight_decay = 1e-4 gamma = 0.5 lr = 0.001 which_data_set = args.dataset # 'cifar100', 'cifar10', 'pc' noise_level = args.noise # noise level noise_type = args.type # "uniform", "asymmetric" train_val_ratio = 0.8 which_net = args.net # "resnet34", "resnet18", "pc" # -- denoising related parameters -- k_cc = args.k_cc k_outlier = args.k_outlier when_to_denoise = args.start_clean # starting from which epoch we denoise denoise_every_n_epoch = args.every # every n epoch, we perform denoising # -- specify dataset -- # data augmentation if which_data_set[:5] == 'cifar': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) else: transform_train = None transform_test = None if which_data_set == 'cifar10': trainset = CIFAR10(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR10(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR10(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 10 in_channel = 3 elif which_data_set == 'cifar100': trainset = CIFAR100(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR100(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR100(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 100 in_channel = 3 elif which_data_set == 'pc': trainset = ModelNet40(split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) valset = 
ModelNet40(split='val', train_ratio=train_val_ratio, num_ptrs=1024) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = ModelNet40(split='test', num_ptrs=1024) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 40 else: ValueError('Invalid Dataset!') print('train data size:', len(trainset)) print('validation data size:', len(valset)) print('test data size:', len(testset)) ntrain = len(trainset) # -- generate noise -- y_train = trainset.get_data_labels() y_train = np.array(y_train) noise_y_train = None keep_indices = None p = None if noise_type == 'none': pass else: if noise_type == "uniform": noise_y_train, p, keep_indices = noisify_with_P( y_train, nb_classes=num_class, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) print("apply uniform noise") else: if which_data_set == 'cifar10': noise_y_train, p, keep_indices = noisify_cifar10_asymmetric( y_train, noise=noise_level, random_state=random_seed) elif which_data_set == 'cifar100': noise_y_train, p, keep_indices = noisify_cifar100_asymmetric( y_train, noise=noise_level, random_state=random_seed) elif which_data_set == 'pc': noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric( y_train, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) print("apply asymmetric noise") print("clean data num:", len(keep_indices)) print("probability transition matrix:\n{}".format(p)) # -- create log file -- file_name = '(' + which_data_set + '_' + which_net + ')' \ + 'type_' + noise_type + '_noise_' + str(noise_level) \ + '_k_cc_' + str(k_cc) + '_k_outlier_' + str(k_outlier) + '_start_' + str(when_to_denoise) \ + '_every_' + str(denoise_every_n_epoch) + '.txt' log_dir = check_folder('logs/logs_txt_' + str(random_seed)) file_name = os.path.join(log_dir, file_name) saver = open(file_name, "w") saver.write( 'noise type: {}\nnoise level: {}\nk_cc: {}\nk_outlier: {}\nwhen_to_apply_epoch: {}\n' .format(noise_type, noise_level, k_cc, k_outlier, when_to_denoise)) if noise_type != 'none': saver.write('total clean data num: {}\n'.format(len(keep_indices))) saver.write('probability transition matrix:\n{}\n'.format(p)) saver.flush() # -- set network, optimizer, scheduler, etc -- if which_net == 'resnet18': net = resnet18(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'resnet34': net = resnet34(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'pc': net = PointNetCls(k=num_class) feature_size = 256 else: ValueError('Invalid network!') optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") net = net.to(device) ################################################ exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestone, gamma=gamma) criterion = nn.NLLLoss() # since the output of network is by log softmax # -- misc -- best_acc = 0 best_epoch = 0 best_weights = None curr_trainloader = trainloader big_comp = set() patience = args.patience no_improve_counter = 0 # -- start training -- for epoch in range(start_epoch, num_epoch): train_correct = 0 train_loss = 0 train_total = 0 exp_lr_scheduler.step() net.train() print("current train data size:", len(curr_trainloader.dataset)) for _, (images, labels, _) in enumerate(curr_trainloader): if images.size( 0 ) == 1: # when batch size 
equals 1, skip, due to batch normalization continue images, labels = images.to(device), labels.to(device) outputs, features = net(images) log_outputs = torch.log_softmax(outputs, 1) loss = criterion(log_outputs, labels) optimizer.zero_grad() loss.backward() optimizer.step() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() train_acc = train_correct / train_total * 100. cprint('epoch: {}'.format(epoch), 'white') cprint( 'train accuracy: {}\ntrain loss:{}'.format(train_acc, train_loss), 'yellow') # --- compute big connected components --- net.eval() features_all = torch.zeros(ntrain, feature_size).to(device) prob_all = torch.zeros(ntrain, num_class) labels_all = torch.zeros(ntrain, num_class) train_gt_labels = np.zeros(ntrain, dtype=np.uint64) train_pred_labels = np.zeros(ntrain, dtype=np.uint64) for _, (images, labels, indices) in enumerate(trainloader): images = images.to(device) outputs, features = net(images) softmax_outputs = torch.softmax(outputs, 1) features_all[indices] = features.detach() prob_all[indices] = softmax_outputs.detach().cpu() tmp_zeros = torch.zeros(labels.shape[0], num_class) tmp_zeros[torch.arange(labels.shape[0]), labels] = 1.0 labels_all[indices] = tmp_zeros train_gt_labels[indices] = labels.cpu().numpy().astype(np.int64) train_pred_labels[indices] = labels.cpu().numpy().astype(np.int64) if epoch >= when_to_denoise and ( epoch - when_to_denoise) % denoise_every_n_epoch == 0: cprint('\n>> Computing Big Components <<', 'white') labels_all = labels_all.numpy() train_gt_labels = train_gt_labels.tolist() train_pred_labels = np.squeeze(train_pred_labels).ravel().tolist() _, idx_of_comp_idx2 = calc_topo_weights_with_components_idx( ntrain, labels_all, features_all, train_gt_labels, train_pred_labels, k=k_cc, use_log=False, cp_opt=3, nclass=num_class) # --- update largest connected component --- cur_big_comp = list(set(range(ntrain)) - set(idx_of_comp_idx2)) big_comp = big_comp.union(set(cur_big_comp)) # --- remove outliers in largest connected component --- big_com_idx = list(big_comp) feats_big_comp = features_all[big_com_idx] labels_big_comp = np.array(train_gt_labels)[big_com_idx] knnG_list = calc_knn_graph(feats_big_comp, k=args.k_outlier) knnG_list = np.array(knnG_list) knnG_shape = knnG_list.shape knn_labels = labels_big_comp[knnG_list.ravel()] knn_labels = np.reshape(knn_labels, knnG_shape) majority, counts = mode(knn_labels, axis=-1) majority = majority.ravel() counts = counts.ravel() if args.zeta > 1.0: # use majority vote non_outlier_idx = np.where(majority == labels_big_comp)[0] outlier_idx = np.where(majority != labels_big_comp)[0] outlier_idx = np.array(list(big_comp))[outlier_idx] print(">> majority == labels_big_comp -> size: ", len(non_outlier_idx)) else: # zeta filtering non_outlier_idx = np.where((majority == labels_big_comp) & ( counts >= args.k_outlier * args.zeta))[0] print(">> zeta {}, then non_outlier_idx -> size: {}".format( args.zeta, len(non_outlier_idx))) outlier_idx = np.where(majority != labels_big_comp)[0] outlier_idx = np.array(list(big_comp))[outlier_idx] cprint(">> The number of outliers: {}".format(len(outlier_idx)), 'red') cprint( ">> The purity of outliers: {}".format( np.sum(y_train[outlier_idx] == noise_y_train[outlier_idx]) / float(len(outlier_idx))), 'red') big_comp = np.array(list(big_comp))[non_outlier_idx] big_comp = set(big_comp.tolist()) # --- construct updated dataset set, which contains the collected clean data --- if which_data_set == 
'cifar10': trainset_ignore_noisy_data = CIFAR10( root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) elif which_data_set == 'cifar100': trainset_ignore_noisy_data = CIFAR100( root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) else: trainset_ignore_noisy_data = ModelNet40( split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader_ignore_noisy_data = torch.utils.data.DataLoader( trainset_ignore_noisy_data, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) curr_trainloader = trainloader_ignore_noisy_data noisy_data_indices = list(set(range(ntrain)) - big_comp) trainset_ignore_noisy_data.update_corrupted_label(noise_y_train) trainset_ignore_noisy_data.ignore_noise_data(noisy_data_indices) clean_data_num = len(big_comp.intersection(set(keep_indices))) noise_data_num = len(big_comp) - clean_data_num print("Big Comp Number:", len(big_comp)) print("Found Noisy Data Number:", noise_data_num) print("Found True Data Number:", clean_data_num) # compute purity of the component cc_size = len(big_comp) equal = np.sum( noise_y_train[list(big_comp)] == y_train[list(big_comp)]) ratio = equal / float(cc_size) print("Purity of current component: {}".format(ratio)) noise_size = len(noisy_data_indices) equal = np.sum(noise_y_train[noisy_data_indices] == y_train[noisy_data_indices]) print("Purity of data outside component: {}".format( equal / float(noise_size))) saver.write('Purity {}\tsize{}\t'.format(ratio, cc_size)) # --- validation --- val_total = 0 val_correct = 0 net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100. if val_acc > best_acc: best_acc = val_acc best_epoch = epoch best_weights = copy.deepcopy(net.state_dict()) no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print( '>> No improvement for {} epochs. Stop at epoch {}'.format( patience, epoch)) saver.write( '>> No improvement for {} epochs. Stop at epoch {}'.format( patience, epoch)) saver.write( '>> val epoch: {}\n>> current accuracy: {}\n'.format( epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format( best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint( '>> best accuracy: {}\n>> best epoch: {}\n'.format( best_acc, best_epoch), 'green') saver.write('{}\n'.format(val_acc)) # -- testing cprint('>> testing <<', 'cyan') test_total = 0 test_correct = 0 net.load_state_dict(best_weights) net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(testloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) test_total += images.size(0) _, predicted = outputs.max(1) test_correct += predicted.eq(labels).sum().item() test_acc = test_correct / test_total * 100. 
cprint('>> test accuracy: {}'.format(test_acc), 'cyan') saver.write('>> test accuracy: {}\n'.format(test_acc)) # retest on the validation set, for sanity check cprint('>> validation <<', 'cyan') val_total = 0 val_correct = 0 net.eval() with torch.no_grad(): for _, (images, labels, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100. cprint('>> validation accuracy: {}'.format(val_acc), 'cyan') saver.write('>> validation accuracy: {}'.format(val_acc)) saver.close() return test_acc
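# --- Hedged sketch of the neighbour-vote outlier filter used in the loop above ---
# Synthetic data only: in the real code the neighbour lists come from
# calc_knn_graph on the learned features of the big connected component. A point
# is kept when the majority label of its k nearest neighbours agrees with its own
# (possibly noisy) label.
import numpy as np
from scipy.stats import mode

labels_big_comp = np.array([0, 0, 1, 1, 0, 1])
knnG_list = np.array([[1, 4, 2], [0, 4, 3], [3, 5, 1],
                      [2, 5, 0], [0, 1, 2], [2, 3, 4]])     # hypothetical k=3 neighbours
knn_labels = labels_big_comp[knnG_list]                     # labels of each point's neighbours
majority, counts = mode(knn_labels, axis=-1)
majority, counts = np.asarray(majority).ravel(), np.asarray(counts).ravel()
non_outlier_idx = np.where(majority == labels_big_comp)[0]  # neighbours agree -> keep
print(non_outlier_idx)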
def create_model(Model, num_classes):
    if Model == "ResNet18":
        model = resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif Model == "ResNet34":
        model = resnet34(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif Model == "ResNet50":
        model = resnet50(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif Model == "ResNet101":
        model = resnet101(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif Model == "MobileNet_v2":
        model = mobilenet_v2(num_classes=num_classes, pretrained=False)
    elif Model == "Mobilenetv3":
        model = mobilenetv3(n_class=num_classes, pretrained=False)
    elif Model == "shufflenet_v2_x0_5":
        model = shufflenet_v2_x0_5(pretrained=False, num_classes=num_classes)
    elif Model == "shufflenet_v2_x1_0":
        model = shufflenet_v2_x1_0(pretrained=False, num_classes=num_classes)
    elif Model == "shufflenet_v2_x1_5":
        model = shufflenet_v2_x1_5(pretrained=False, num_classes=num_classes)
    elif Model == "shufflenet_v2_x2_0":
        model = shufflenet_v2_x2_0(pretrained=False, num_classes=num_classes)
    elif "efficientnet" in Model:
        model = EfficientNet.from_pretrained(Model, num_classes=num_classes)
    elif Model == "inception_v3":
        model = inception_v3(pretrained=False, num_classes=num_classes)
    elif Model == "mnasnet0_5":
        model = mnasnet0_5(pretrained=False, num_classes=num_classes)
    elif Model == "vgg11":
        model = vgg11(pretrained=False, num_classes=num_classes)
    elif Model == "vgg11_bn":
        model = vgg11_bn(pretrained=False, num_classes=num_classes)
    elif Model == "vgg19":
        model = vgg19(pretrained=False, num_classes=num_classes)
    elif Model == "densenet121":
        model = densenet121(pretrained=False)
    elif Model == "ResNeXt29_32x4d":
        model = ResNeXt29_32x4d(num_classes=num_classes)
    elif Model == "ResNeXt29_2x64d":
        model = ResNeXt29_2x64d(num_classes=num_classes)
    else:
        raise ValueError("Unknown model: {}".format(Model))

    model = model.cuda()

    # input = torch.randn(1, 3, 32, 32).cuda()
    # flops, params = profile(model, inputs=(input,))
    # flops, params = clever_format([flops, params], "%.3f")
    # print("------------------------------------------------------------------------")
    # print("  ", Model)
    # print("  flops:", flops, "  params:", params)
    # print("------------------------------------------------------------------------")
    return model
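# --- Hedged usage sketch of create_model (assumes a CUDA device and CIFAR-sized inputs) ---
import torch

if torch.cuda.is_available():
    model = create_model("ResNet18", num_classes=10)
    x = torch.randn(8, 3, 32, 32).cuda()
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)   # expected: torch.Size([8, 10])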
def main(args): random_seed = int(np.random.choice(range(1000), 1)) np.random.seed(random_seed) random.seed(random_seed) torch.manual_seed(random_seed) torch.cuda.manual_seed(random_seed) torch.cuda.manual_seed_all(random_seed) torch.backends.cudnn.deterministic = True arg_which_net = args.network arg_dataset = args.dataset arg_epoch_start = args.epoch_start lr = args.lr arg_gpu = args.gpu arg_num_gpu = args.n_gpus arg_every_n_epoch = args.every_n_epoch # interval to perform the correction arg_epoch_update = args.epoch_update # the epoch to start correction (warm-up period) arg_epoch_interval = args.epoch_interval # interval between two update of A noise_level = args.noise_level noise_type = args.noise_type # "uniform", "asymmetric", "none" train_val_ratio = 0.9 which_net = arg_which_net # "cnn" "resnet18" "resnet34" "preact_resnet18" "preact_resnet34" "preact_resnet101" "pc" num_epoch = args.n_epochs # Total training epochs print('Using {}\nTest on {}\nRandom Seed {}\nevery n epoch {}\nStart at epoch {}'. format(arg_which_net, arg_dataset, random_seed, arg_every_n_epoch, arg_epoch_start)) # -- training parameters if arg_dataset == 'mnist': milestone = [30, 60] batch_size = 64 in_channels = 1 elif arg_dataset == 'cifar10': milestone = [60, 180] batch_size = 128 in_channels = 1 elif arg_dataset == 'cifar100': milestone = [60, 180] batch_size = 128 in_channels = 1 elif arg_dataset == 'pc': milestone = [30, 60] batch_size = 128 start_epoch = 0 num_workers = 1 #gamma = 0.5 # -- specify dataset # data augmentation if arg_dataset == 'cifar10': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ]) elif arg_dataset == 'cifar100': transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.507, 0.487, 0.441), (0.507, 0.487, 0.441)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.507, 0.487, 0.441), (0.507, 0.487, 0.441)), ]) elif arg_dataset == 'mnist': transform_train = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), ]) else: transform_train = None transform_test = None if arg_dataset == 'cifar10': trainset = CIFAR10(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn) valset = CIFAR10(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR10(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) images_size = [3, 32, 32] num_class = 10 in_channel = 3 elif arg_dataset == 'cifar100': trainset = CIFAR100(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train) 
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn) valset = CIFAR100(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = CIFAR100(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 100 in_channel = 3 elif arg_dataset == 'mnist': trainset = MNIST(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn) valset = MNIST(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = MNIST(root='./data', split='test', download=True, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 10 in_channel = 1 elif arg_dataset == 'pc': trainset = ModelNet40(split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True) valset = ModelNet40(split='val', train_ratio=train_val_ratio, num_ptrs=1024) valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers) testset = ModelNet40(split='test', num_ptrs=1024) testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers) num_class = 40 print('train data size:', len(trainset)) print('validation data size:', len(valset)) print('test data size:', len(testset)) eps = 1e-6 # This is the epsilon used to soft the label (not the epsilon in the paper) ntrain = len(trainset) # -- generate noise -- # y_train is ground truth labels, we should not have any access to this after the noisy labels are generated # algorithm after y_tilde is generated has nothing to do with y_train y_train = trainset.get_data_labels() y_train = np.array(y_train) noise_y_train = None keep_indices = None p = None if(noise_type == 'none'): pass else: if noise_type == "uniform": noise_y_train, p, keep_indices = noisify_with_P(y_train, nb_classes=num_class, noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) noise_softlabel = torch.ones(ntrain, num_class)*eps/(num_class-1) noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1-eps) trainset.update_corrupted_softlabel(noise_softlabel) print("apply uniform noise") else: if arg_dataset == 'cifar10': noise_y_train, p, keep_indices = noisify_cifar10_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'cifar100': noise_y_train, p, keep_indices = noisify_cifar100_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'mnist': noise_y_train, p, keep_indices = noisify_mnist_asymmetric(y_train, noise=noise_level, random_state=random_seed) elif arg_dataset == 'pc': noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric(y_train, 
noise=noise_level, random_state=random_seed) trainset.update_corrupted_label(noise_y_train) noise_softlabel = torch.ones(ntrain, num_class) * eps / (num_class - 1) noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1 - eps) trainset.update_corrupted_softlabel(noise_softlabel) print("apply asymmetric noise") print("clean data num:", len(keep_indices)) print("probability transition matrix:\n{}".format(p)) # -- create log file file_name = '[' + arg_dataset + '_' + which_net + ']' \ + 'type:' + noise_type + '_' + 'noise:' + str(noise_level) + '_' \ + '_' + 'start:' + str(arg_epoch_start) + '_' \ + 'every:' + str(arg_every_n_epoch) + '_'\ + 'time:' + str(datetime.datetime.now()) + '.txt' log_dir = check_folder('new_logs/logs_txt_' + str(random_seed)) file_name = os.path.join(log_dir, file_name) saver = open(file_name, "w") saver.write('noise type: {}\nnoise level: {}\nwhen_to_apply_epoch: {}\n'.format( noise_type, noise_level, arg_epoch_start)) if noise_type != 'none': saver.write('total clean data num: {}\n'.format(len(keep_indices))) saver.write('probability transition matrix:\n{}\n'.format(p)) saver.flush() # -- set network, optimizer, scheduler, etc if which_net == "cnn": net_trust = CNN9LAYER(input_channel=in_channel, n_outputs=num_class) net = CNN9LAYER(input_channel=in_channel, n_outputs=num_class) net.apply(weight_init) feature_size = 128 elif which_net == 'resnet18': net_trust = resnet18(in_channel=in_channel, num_classes=num_class) net = resnet18(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'resnet34': net_trust = resnet34(in_channel=in_channel, num_classes=num_class) net = resnet34(in_channel=in_channel, num_classes=num_class) feature_size = 512 elif which_net == 'preact_resnet18': net_trust = preact_resnet18(num_classes=num_class, num_input_channels=in_channel) net = preact_resnet18(num_classes=num_class, num_input_channels=in_channel) feature_size = 256 elif which_net == 'preact_resnet34': net_trust = preact_resnet34(num_classes=num_class, num_input_channels=in_channel) net = preact_resnet34(num_classes=num_class, num_input_channels=in_channel) feature_size = 256 elif which_net == 'preact_resnet101': net_trust = preact_resnet101() net = preact_resnet101() feature_size = 256 elif which_net == 'pc': net_trust = PointNetCls(k=num_class) net = PointNetCls(k=num_class) feature_size = 256 else: ValueError('Invalid network!') opt_gpus = [i for i in range(arg_gpu, arg_gpu+int(arg_num_gpu))] if len(opt_gpus) > 1: print("Using ", len(opt_gpus), " GPUs") os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in opt_gpus) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(device) if len(opt_gpus) > 1: net_trust = torch.nn.DataParallel(net_trust) net = torch.nn.DataParallel(net) net_trust.to(device) net.to(device) # net.apply(conv_init) # ------------------------- Initialize Setting ------------------------------ best_acc = 0 best_epoch = 0 patience = 50 no_improve_counter = 0 A = 1/num_class*torch.ones(ntrain, num_class, num_class, requires_grad=False).float().to(device) h = np.zeros([ntrain, num_class]) criterion_1 = nn.NLLLoss() pred_softlabels = np.zeros([ntrain, arg_every_n_epoch, num_class], dtype=np.float) train_acc_record = [] clean_train_acc_record = [] noise_train_acc_record = [] val_acc_record = [] recovery_record = [] noise_ytrain = copy.copy(noise_y_train) #noise_ytrain = torch.tensor(noise_ytrain).to(device) cprint("================ Clean Label... 
================", "yellow") for epoch in range(num_epoch): # Add some modification here train_correct = 0 train_loss = 0 train_total = 0 delta = 1.2 + 0.02*max(epoch - arg_epoch_update + 1, 0) clean_train_correct = 0 noise_train_correct = 0 optimizer_trust = RAdam(net_trust.parameters(), lr=learning_rate(lr, epoch), weight_decay=5e-4) #optimizer_trust = optim.SGD(net_trust.parameters(), lr=learning_rate(lr, epoch), weight_decay=5e-4, # nesterov=True, momentum=0.9) net_trust.train() # Train with noisy data for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)): if images.size(0) == 1: # when batch size equals 1, skip, due to batch normalization continue images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device) outputs, features = net_trust(images) log_outputs = torch.log_softmax(outputs, 1).float() # arg_epoch_start : epoch start to introduce loss retro # arg_epoch_interval : epochs between two updating of A if epoch in [arg_epoch_start-1, arg_epoch_start+arg_epoch_interval-1]: h[indices] = log_outputs.detach().cpu() normal_outputs = torch.softmax(outputs, 1) if epoch >= arg_epoch_start: # use loss_retro + loss_ce A_batch = A[indices].to(device) loss = sum([-A_batch[i].matmul(softlabels[i].reshape(-1, 1).float()).t().matmul(log_outputs[i]) for i in range(len(indices))]) / len(indices) + \ criterion_1(log_outputs, labels) else: # use loss_ce loss = criterion_1(log_outputs, labels) optimizer_trust.zero_grad() loss.backward() optimizer_trust.step() #arg_every_n_epoch : rolling windows to get eta_tilde if epoch >= (arg_epoch_update - arg_every_n_epoch): pred_softlabels[indices, epoch % arg_every_n_epoch, :] = normal_outputs.detach().cpu().numpy() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() # For monitoring purpose, comment the following line out if you have your own dataset that doesn't have ground truth label train_label_clean = torch.tensor(y_train)[indices].to(device) train_label_noise = torch.tensor(noise_ytrain[indices]).to(device) clean_train_correct += predicted.eq(train_label_clean).sum().item() noise_train_correct += predicted.eq(train_label_noise).sum().item() # acc wrt the original noisy labels train_acc = train_correct / train_total * 100 clean_train_acc = clean_train_correct/train_total*100 noise_train_acc = noise_train_correct/train_total*100 print(" Train Epoch: [{}/{}] \t Training Acc wrt Corrected {:.3f} \t Train Acc wrt True {:.3f} \t Train Acc wrt Noise {:.3f}". 
format(epoch, num_epoch, train_acc, clean_train_acc, noise_train_acc)) # updating A if epoch in [arg_epoch_start - 1, arg_epoch_start + 40 - 1]: cprint("+++++++++++++++++ Updating A +++++++++++++++++++", "magenta") unsolved = 0 infeasible = 0 y_soft = trainset.get_data_softlabel() with torch.no_grad(): for i in tqdm(range(ntrain), ncols=100, ascii=True): try: result, A_opt = updateA(y_soft[i], h[i], rho=0.9) except: A[i] = A[i] unsolved += 1 continue if (result == np.inf): A[i] = A[i] infeasible += 1 else: A[i] = torch.tensor(A_opt) print(A[0]) print("Unsolved points: {} | Infeasible points: {}".format(unsolved, infeasible)) # applying lRT scheme # args_epoch_update : epoch to update labels if epoch >= arg_epoch_update: y_tilde = trainset.get_data_labels() pred_softlabels_bar = pred_softlabels.mean(1) clean_labels, clean_softlabels = lrt_flip_scheme(pred_softlabels_bar, y_tilde, delta) trainset.update_corrupted_softlabel(clean_softlabels) trainset.update_corrupted_label(clean_softlabels.argmax(1)) # validation if not (epoch % 5): val_total = 0 val_correct = 0 net_trust.eval() with torch.no_grad(): for i, (images, labels, _, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net_trust(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100 train_acc_record.append(train_acc) val_acc_record.append(val_acc) clean_train_acc_record.append(clean_train_acc) noise_train_acc_record.append(noise_train_acc) recovery_acc = np.sum(trainset.get_data_labels() == y_train) / ntrain recovery_record.append(recovery_acc) if val_acc > best_acc: best_acc = val_acc best_epoch = epoch no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch)) saver.write('>> No improvement for {} epochs. 
Stop at epoch {}'.format(patience, epoch)) saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green') cprint('>> final recovery rate: {}\n'.format(recovery_acc), 'green') saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write("outputs: {}\n".format(normal_outputs)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) saver.write( '>> final recovery rate: {}%\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain * 100)) saver.flush() # If want to train the neural network again with corrected labels and original loss_ce again # set args.two_stage to True print("Use Two-Stage Model {}".format(args.two_stage)) if args.two_stage==True: criterion_2 = nn.NLLLoss() best_acc = 0 best_epoch = 0 patience = 50 no_improve_counter = 0 cprint("================ Normal Training ================", "yellow") for epoch in range(num_epoch): # Add some modification here train_correct = 0 train_loss = 0 train_total = 0 optimizer_trust = optim.SGD(net.parameters(), momentum=0.9, nesterov=True, lr=learning_rate(lr, epoch), weight_decay=5e-4) net.train() for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)): if images.size(0) == 1: # when batch size equals 1, skip, due to batch normalization continue images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device) outputs, features = net(images) log_outputs = torch.log_softmax(outputs, 1).float() loss = criterion_2(log_outputs, labels) optimizer_trust.zero_grad() loss.backward() optimizer_trust.step() train_loss += loss.item() train_total += images.size(0) _, predicted = outputs.max(1) train_correct += predicted.eq(labels).sum().item() train_acc = train_correct / train_total * 100 print(" Train Epoch: [{}/{}] \t Training Accuracy {}%".format(epoch, num_epoch, train_acc)) if not (epoch % 5): val_total = 0 val_correct = 0 net_trust.eval() with torch.no_grad(): for i, (images, labels, _, _) in enumerate(valloader): images, labels = images.to(device), labels.to(device) outputs, _ = net(images) val_total += images.size(0) _, predicted = outputs.max(1) val_correct += predicted.eq(labels).sum().item() val_acc = val_correct / val_total * 100 if val_acc > best_acc: best_acc = val_acc best_epoch = epoch no_improve_counter = 0 else: no_improve_counter += 1 if no_improve_counter >= patience: print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch)) saver.write('>> No improvement for {} epochs. 
Stop at epoch {}'.format(patience, epoch)) saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc)) saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch)) break cprint('val accuracy: {}'.format(val_acc), 'cyan') cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green') cprint('>> final recovery rate: {}\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain), 'green') cprint("================ Start Testing ================", "yellow") test_total = 0 test_correct = 0 net_trust.eval() for i, (images, labels, softlabels, indices) in enumerate(testloader): if images.shape[0] == 1: continue images, labels = images.to(device), labels.to(device) outputs, _ = net_trust(images) test_total += images.shape[0] test_correct += outputs.argmax(1).eq(labels).sum().item() test_acc = test_correct/test_total*100 print("Final test accuracy {} %".format(test_correct/test_total*100)) return test_acc
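# --- Hedged illustration of the soft-label initialisation used earlier in this routine ---
# Each noisy hard label becomes a distribution with mass 1 - eps on that label and
# eps / (C - 1) spread over the remaining classes, so every row still sums to 1.
# Values here are made up for illustration.
import torch

eps, num_class = 1e-6, 5
noisy_labels = torch.tensor([2, 0, 4])
soft = torch.ones(len(noisy_labels), num_class) * eps / (num_class - 1)
soft.scatter_(1, noisy_labels.reshape(-1, 1), 1 - eps)
print(soft[0])        # ~[2.5e-7, 2.5e-7, 1 - 1e-6, 2.5e-7, 2.5e-7]
print(soft.sum(1))    # each row sums to exactly 1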
def main_worker(gpu, ngpus_per_node, args):
    global best_acc1, best_loc1, best_epoch, \
        loc1_at_best_acc1, acc1_at_best_loc1, \
        gtknown_at_best_acc1, gtknown_at_best_loc1
    global writer

    args.gpu = gpu
    log_folder = os.path.join('train_log', args.name, ts)
    args.save_dir = log_folder

    if args.gpu == 0:
        writer = SummaryWriter(logdir=log_folder)

    if not os.path.isdir(log_folder):
        os.makedirs(log_folder, exist_ok=True)
    with open('{}/args.json'.format(log_folder), 'w') as fp:
        json.dump(args.__dict__, fp)

    Logger(os.path.join(log_folder, 'log.log'))
    print('args: ', args)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    if args.dataset == 'CUB':
        num_classes = 200
    elif args.dataset == 'tiny_imagenet':
        num_classes = 200
    elif args.dataset == 'ILSVRC':
        num_classes = 1000
    else:
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    if args.arch == 'vgg16':
        model = vgg.vgg16(pretrained=True, num_classes=num_classes)
    elif args.arch == 'vgg16_GAP':
        model = vgg.vgg16_GAP(pretrained=True, num_classes=num_classes)
    elif args.arch == 'vgg16_ADL':
        model = vgg.vgg16_ADL(pretrained=True, num_classes=num_classes,
                              ADL_position=args.ADL_position,
                              drop_rate=args.ADL_rate, drop_thr=args.ADL_thr)
    elif args.arch == 'resnet50_ADL':
        model = resnet.resnet50(pretrained=True, num_classes=num_classes,
                                ADL_position=args.ADL_position,
                                drop_rate=args.ADL_rate, drop_thr=args.ADL_thr)
    elif args.arch == 'resnet50':
        model = resnet.resnet50(pretrained=True, num_classes=num_classes)
    elif args.arch == 'resnet34_ADL':
        model = resnet.resnet34(pretrained=True, num_classes=num_classes,
                                ADL_position=args.ADL_position,
                                drop_rate=args.ADL_rate, drop_thr=args.ADL_thr)
    elif args.arch == 'se_resnet50_ADL':
        model = resnet.resnet50_se(pretrained=True, num_classes=num_classes,
                                   ADL_position=args.ADL_position,
                                   drop_rate=args.ADL_rate, drop_thr=args.ADL_thr)
    else:
        raise ValueError("Unsupported architecture: {}".format(args.arch))

    if args.distributed:
        # For multiprocessing distributed, the DistributedDataParallel constructor
        # should always set the single device scope; otherwise
        # DistributedDataParallel will use all available devices.
        if args.gpu is not None:
            torch.cuda.set_device(args.gpu)
            model.cuda(args.gpu)
            # When using a single GPU per process and per DistributedDataParallel,
            # we need to divide the batch size ourselves based on the total number
            # of GPUs we have
            args.batch_size = int(args.batch_size / ngpus_per_node)
            args.workers = int(args.workers / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.gpu], find_unused_parameters=True)
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    # split parameters so the classifier head can use a scaled learning rate
    param_features = []
    param_classifiers = []
    if args.arch.startswith('vgg'):
        for name, parameter in model.named_parameters():
            if 'features.' in name:
                param_features.append(parameter)
            else:
                param_classifiers.append(parameter)
    elif args.arch.startswith('resnet') or args.arch.startswith('se'):
        for name, parameter in model.named_parameters():
            if 'layer4.' in name or 'fc.' in name:
                param_classifiers.append(parameter)
            else:
                param_features.append(parameter)
    else:
        raise Exception("Fail to recognize the architecture")

    optimizer = torch.optim.SGD(
        [{'params': param_features, 'lr': args.lr},
         {'params': param_classifiers, 'lr': args.lr * args.lr_ratio}],
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        nesterov=args.nest)

    # optionally resume from a checkpoint
    if args.resume:
        model, optimizer = load_model(model, optimizer, args)
        # for param_group in optimizer.param_groups:
        #     param_group['lr'] = args.lr

    cudnn.benchmark = True

    # CUB-200-2011
    train_loader, val_loader, train_sampler = data_loader(args)

    if args.cam_curve:
        cam_curve(val_loader, model, criterion, writer, args)
        return

    if args.evaluate:
        evaluate(val_loader, model, criterion, args)
        return

    if args.gpu == 0:
        print("Batch Size per Tower: %d" % (args.batch_size))
        print(model)

    for epoch in range(args.start_epoch, args.epochs):
        if args.gpu == 0:
            print("===========================================================")
            print("Start Epoch %d ..." % (epoch + 1))

        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args)

        val_acc1 = 0
        val_loss = 0
        val_gtloc = 0
        val_loc = 0

        # train for one epoch
        train_acc, train_loss, progress_train = \
            train(train_loader, model, criterion, optimizer, epoch, args)

        if args.gpu == 0:
            progress_train.display(epoch + 1)

        # evaluate classification on the validation set
        if args.task == 'cls':
            val_acc1, val_loss = validate(val_loader, model, criterion, epoch, args)
        # evaluate localization on the validation set
        elif args.task == 'wsol':
            val_acc1, val_acc5, val_loss, \
                val_gtloc, val_loc = evaluate_loc(val_loader, model, criterion, epoch, args)

        # tensorboard
        if args.gpu == 0:
            writer.add_scalar(args.name + '/train_acc', train_acc, epoch)
            writer.add_scalar(args.name + '/train_loss', train_loss, epoch)
            writer.add_scalar(args.name + '/val_cls_acc', val_acc1, epoch)
            writer.add_scalar(args.name + '/val_loss', val_loss, epoch)
            writer.add_scalar(args.name + '/val_gt_loc', val_gtloc, epoch)
            writer.add_scalar(args.name + '/val_loc1', val_loc, epoch)

        # remember best acc@1 and save checkpoint
        is_best = val_acc1 > best_acc1
        best_acc1 = max(val_acc1, best_acc1)
        if is_best:
            best_epoch = epoch + 1
            loc1_at_best_acc1 = val_loc
            gtknown_at_best_acc1 = val_gtloc

        if args.task == 'wsol':
            # also track the best localization accuracy (not used for model selection)
            is_best_loc = val_loc > best_loc1
            best_loc1 = max(val_loc, best_loc1)
            if is_best_loc:
                best_epoch = epoch + 1
                acc1_at_best_loc1 = val_acc1
                gtknown_at_best_loc1 = val_gtloc

        if args.gpu == 0:
            print("\nCurrent Best Epoch: %d" % (best_epoch))
            print("Top-1 GT-Known Localization Acc: %.3f"
                  "\nTop-1 Localization Acc: %.3f"
                  "\nTop-1 Classification Acc: %.3f" %
                  (gtknown_at_best_acc1, loc1_at_best_acc1, best_acc1))
            print("\nEpoch %d finished." % (epoch + 1))

        if not args.multiprocessing_distributed or \
                (args.multiprocessing_distributed and args.rank % ngpus_per_node == 0):
            saving_dir = os.path.join(log_folder)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_acc1': best_acc1,
                'optimizer': optimizer.state_dict(),
            }, is_best, saving_dir)

    if args.gpu == 0:
        save_train(best_acc1, loc1_at_best_acc1, gtknown_at_best_acc1,
                   best_loc1, acc1_at_best_loc1, gtknown_at_best_loc1, args)

        print("===========================================================")
        print("Start Evaluation on Best Checkpoint ...")

        args.resume = os.path.join(log_folder, 'model_best.pth.tar')
        model, _ = load_model(model, optimizer, args)
        evaluate(val_loader, model, criterion, args)
        cam_curve(val_loader, model, criterion, writer, args)
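# --- Hedged sketch of the two-group optimiser used above ---
# layer4/fc parameters get lr * lr_ratio, everything else the base lr. A stock
# torchvision resnet50 and made-up hyperparameters are used here purely for
# illustration (the project builds its own ADL/SE variants).
import torch
from torchvision.models import resnet50 as tv_resnet50

model = tv_resnet50(weights=None)
lr, lr_ratio = 0.01, 10.0
param_features, param_classifiers = [], []
for name, p in model.named_parameters():
    if 'layer4.' in name or 'fc.' in name:
        param_classifiers.append(p)
    else:
        param_features.append(p)
optimizer = torch.optim.SGD(
    [{'params': param_features, 'lr': lr},
     {'params': param_classifiers, 'lr': lr * lr_ratio}],
    momentum=0.9, weight_decay=1e-4, nesterov=True)
print(len(param_features), len(param_classifiers))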