def __init__(self, trunk_name, output_stride=8, pretrained=True):
        super(get_resnet, self).__init__()

        if trunk_name == "resnet18":
            resnet = resnet18(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "resnet34":
            resnet = resnet34(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "resnet50":
            resnet = resnet50(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "resnet101":
            resnet = resnet101(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "resnet152":
            resnet = resnet152(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "resnext101_32x8d":
            resnet = resnext101_32x8d(pretrained=pretrained)
            resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1,
                                          resnet.relu, resnet.maxpool)
        elif trunk_name == "se_resnet50":
            resnet = se_resnext50_32x4d(pretrained=pretrained)
        elif trunk_name == "se_resnet101":
            resnet = se_resnext101_32x4d(pretrained=pretrained)
        else:
            raise KeyError("[*] Not a valid network arch")

        self.layer0 = resnet.layer0
        self.layer1, self.layer2, self.layer3, self.layer4 = \
            resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

        if output_stride == 8:
            for n, m in self.layer3.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        elif output_stride == 16:
            for n, m in self.layer4.named_modules():
                if 'conv2' in n:
                    m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
                elif 'downsample.0' in n:
                    m.stride = (1, 1)
        else:
            raise KeyError(
                "[*] Unsupported output_stride {}".format(output_stride))
Example #2
 def __init__(self, num_classes=256, input_bits=8, embed_dim=8, seq_len=100, backbone='resnet50'):
     super(CNN, self).__init__()
     self.input_bits = input_bits
     self.seq_len = seq_len
     # self.embed = nn.Embedding(256, 256)
     self.embed = nn.Embedding(2 ** input_bits, embed_dim)
     # map names to constructors and build only the selected backbone;
     # fall back to resnet50 (the constructor, not the string) for unknown names
     self.backbone = {
         'resnet18': resnet18,
         'resnet34': resnet34,
         'resnet50': resnet50,
         'resnet101': resnet101,
         'resnet152': resnet152,
     }.get(backbone, resnet50)(n_classes=num_classes, input_channels=embed_dim)
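A runnable shape sketch for the byte-embedding front end above. The custom resnet* backbones here take (n_classes, input_channels) and their code is not shown, so a small Conv1d stack stands in for the backbone purely to check the (batch, seq_len) -> (batch, embed_dim, seq_len) layout; the stand-in is an assumption, not the project's model.

import torch
import torch.nn as nn

input_bits, embed_dim, seq_len, num_classes = 8, 8, 100, 256
embed = nn.Embedding(2 ** input_bits, embed_dim)
stand_in_backbone = nn.Sequential(   # placeholder for resnet*(n_classes=..., input_channels=...)
    nn.Conv1d(embed_dim, 64, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
    nn.Linear(64, num_classes),
)

x = torch.randint(0, 2 ** input_bits, (4, seq_len))  # a batch of raw byte traces
h = embed(x)                                         # (4, 100, 8)
h = h.permute(0, 2, 1)                               # (4, 8, 100): channels first for Conv1d
print(stand_in_backbone(h).shape)                    # torch.Size([4, 256])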
Example #3
    def __init__(self, args, train_transform=None, test_transform=None, val_transform=None):
        self.args = args
        self.train_transform = train_transform
        self.test_transform = test_transform
        self.val_transform = val_transform
        
        self.loss_lams = torch.zeros(args.num_classes, args.num_classes, dtype=torch.float32).cuda()
        self.loss_lams[:, :] = 1. / args.num_classes
        self.loss_lams.requires_grad = False

        if self.args.arch == 'resnet50':
            self.net = resnet.resnet50(num_classes=args.num_classes)
        elif self.args.arch == 'res2net50':
            self.net = res2net.res2net50_26w_4s(num_classes=args.num_classes)
        elif self.args.arch == 'resnet101':
            self.net = resnet.resnet101()
        elif self.args.arch == 'resnet18':
            self.net = resnet.resnet18()
        elif self.args.arch == 'resnet34':
            self.net = resnet.resnet34()
        elif self.args.arch == 'vgg16':
            self.net = vgg_cub.vgg16()
        elif self.args.arch == 'vgg16_bn':
            self.net = vgg_cub.vgg16_bn()
        elif self.args.arch == 'vgg19':
            self.net = vgg_cub.vgg19()
        elif self.args.arch == 'vgg19_bn':
            self.net = vgg_cub.vgg19_bn()
        elif self.args.arch == 'vgg16_std':
            self.net = vgg_std.vgg16()
        elif self.args.arch == 'vgg16_bn_std':
            self.net = vgg_std.vgg16_bn()
        elif self.args.arch == 'mobilenetv2':
            self.net = mobilenetv2.mobilenet_v2(num_classes=args.num_classes)
        else:
            raise ValueError('Unsupported arch: {}'.format(self.args.arch))

        if self.args.load_model is not None:
            self.net.load_state_dict(torch.load(self.args.load_model), strict=True)
            print('load model from %s' % self.args.load_model)
        elif self.args.pretrained_model is not None:
            self.net.load_state_dict(torch.load(self.args.pretrained_model), strict=False)
            print('load pretrained model from %s' % self.args.pretrained_model)
        else:
            print('no model loaded, training from scratch!')

        if args.expname is None:
            args.expname = 'runs/{}_{}_{}'.format(args.arch, args.dataset, args.method)
        os.makedirs(args.expname, exist_ok=True)
Example #4
def testRed(num, visualize):
    # Set Test parameters
    params = TestParams()
    params.gpus = []  # leave empty to test on CPU; if len(params.gpus) > 1, params.gpus[0] is used for testing by default
    # this model corresponds to the model trained in the 60th epoch shown in the two training results under ./architecture:
    # red_*_train_1e5_test_2e4_10_kinds_3min_per_epoch_resnet18.png
    params.ckpt = './models/formation_prediction.pth'

    # models
    # model = resnet34(pretrained=False, num_classes=1000)  # batch_size=60, 1GPU Memory > 9000M
    # model.fc = nn.Linear(512, 6)
    model = resnet18(pretrained=False,
                     num_classes=1000)  # batch_size=120, 1GPU Memory < 7000M
    model.fc = nn.Linear(512, formation_num)

    # Test
    tester = RedTester(model, params)
    tester.test(num, visualize)
Example #5
def main():
    MAX_EPOCH = 2

    # train_data / test_data are assumed to be tf.data pipelines prepared elsewhere in the script
    model = ResNet.resnet18()
    model.build(input_shape=(None, 32, 32, 3))
    model.summary()
    optimizer = optimizers.Adam(lr=1e-3)

    for epoch in range(MAX_EPOCH):
        for step, (x, y) in enumerate(train_data):
            with tf.GradientTape() as tape:
                logits = model(x)
                y_onehot = tf.one_hot(y, depth=10)
                loss = tf.losses.categorical_crossentropy(y_onehot,
                                                          logits,
                                                          from_logits=True)
                loss = tf.reduce_mean(loss)
            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            if step % 100 == 0:
                print(epoch, step, 'loss', float(loss))
        total_num = 0
        total_correct = 0
        for x, y in test_data:
            logits = model(x)
            prob = tf.nn.softmax(logits, axis=1)
            pred = tf.argmax(prob, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)
            correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)
            correct = tf.reduce_sum(correct)
            total_num += x.shape[0]
            total_correct += int(correct)
        acc = total_correct / total_num
        print(epoch, 'acc:', acc)
Example #6
def main(args):
    random_seed = args.seed

    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True  # need to set to True as well

    print(
        'Using {}\nTest on {}\nRandom Seed {}\nk_cc {}\nk_outlier {}\nevery n epoch {}\n'
        .format(args.net, args.dataset, args.seed, args.k_cc, args.k_outlier,
                args.every))

    # -- training parameters --
    if args.dataset == 'cifar10' or args.dataset == 'cifar100':
        num_epoch = 180
        milestone = [int(x) for x in args.milestone.split(',')]
        batch_size = 128
    elif args.dataset == 'pc':
        num_epoch = 90
        milestone = [30, 60]
        batch_size = 128
    else:
        raise ValueError('Invalid Dataset!')

    start_epoch = 0
    num_workers = args.nworker

    weight_decay = 1e-4
    gamma = 0.5
    lr = 0.001

    which_data_set = args.dataset  # 'cifar100', 'cifar10', 'pc'
    noise_level = args.noise  # noise level
    noise_type = args.type  # "uniform", "asymmetric"

    train_val_ratio = 0.8
    which_net = args.net  # "resnet34", "resnet18", "pc"

    # -- denoising related parameters --
    k_cc = args.k_cc
    k_outlier = args.k_outlier
    when_to_denoise = args.start_clean  # starting from which epoch we denoise
    denoise_every_n_epoch = args.every  # every n epoch, we perform denoising

    # -- specify dataset --
    # data augmentation
    if which_data_set[:5] == 'cifar':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465),
                                 (0.2023, 0.1994, 0.2010)),
        ])
    else:
        transform_train = None
        transform_test = None

    if which_data_set == 'cifar10':
        trainset = CIFAR10(root='./data',
                           split='train',
                           train_ratio=train_val_ratio,
                           download=True,
                           transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  worker_init_fn=_init_fn)

        valset = CIFAR10(root='./data',
                         split='val',
                         train_ratio=train_val_ratio,
                         download=True,
                         transform=transform_test)
        valloader = torch.utils.data.DataLoader(valset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=num_workers)

        testset = CIFAR10(root='./data',
                          split='test',
                          download=True,
                          transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)

        num_class = 10
        in_channel = 3
    elif which_data_set == 'cifar100':
        trainset = CIFAR100(root='./data',
                            split='train',
                            train_ratio=train_val_ratio,
                            download=True,
                            transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  worker_init_fn=_init_fn)

        valset = CIFAR100(root='./data',
                          split='val',
                          train_ratio=train_val_ratio,
                          download=True,
                          transform=transform_test)
        valloader = torch.utils.data.DataLoader(valset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=num_workers)

        testset = CIFAR100(root='./data',
                           split='test',
                           download=True,
                           transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)

        num_class = 100
        in_channel = 3
    elif which_data_set == 'pc':
        trainset = ModelNet40(split='train',
                              train_ratio=train_val_ratio,
                              num_ptrs=1024,
                              random_jitter=True)
        trainloader = torch.utils.data.DataLoader(trainset,
                                                  batch_size=batch_size,
                                                  shuffle=True,
                                                  num_workers=num_workers,
                                                  worker_init_fn=_init_fn,
                                                  drop_last=True)

        valset = ModelNet40(split='val',
                            train_ratio=train_val_ratio,
                            num_ptrs=1024)
        valloader = torch.utils.data.DataLoader(valset,
                                                batch_size=batch_size,
                                                shuffle=False,
                                                num_workers=num_workers)

        testset = ModelNet40(split='test', num_ptrs=1024)
        testloader = torch.utils.data.DataLoader(testset,
                                                 batch_size=batch_size,
                                                 shuffle=False,
                                                 num_workers=num_workers)

        num_class = 40
    else:
        raise ValueError('Invalid Dataset!')

    print('train data size:', len(trainset))
    print('validation data size:', len(valset))
    print('test data size:', len(testset))
    ntrain = len(trainset)

    # -- generate noise --
    y_train = trainset.get_data_labels()
    y_train = np.array(y_train)

    noise_y_train = None
    keep_indices = None
    p = None

    if noise_type == 'none':
        pass
    else:
        if noise_type == "uniform":
            noise_y_train, p, keep_indices = noisify_with_P(
                y_train,
                nb_classes=num_class,
                noise=noise_level,
                random_state=random_seed)
            trainset.update_corrupted_label(noise_y_train)
            print("apply uniform noise")
        else:
            if which_data_set == 'cifar10':
                noise_y_train, p, keep_indices = noisify_cifar10_asymmetric(
                    y_train, noise=noise_level, random_state=random_seed)
            elif which_data_set == 'cifar100':
                noise_y_train, p, keep_indices = noisify_cifar100_asymmetric(
                    y_train, noise=noise_level, random_state=random_seed)
            elif which_data_set == 'pc':
                noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric(
                    y_train, noise=noise_level, random_state=random_seed)

            trainset.update_corrupted_label(noise_y_train)
            print("apply asymmetric noise")
        print("clean data num:", len(keep_indices))
        print("probability transition matrix:\n{}".format(p))

    # -- create log file --
    file_name = '(' + which_data_set + '_' + which_net + ')' \
                + 'type_' + noise_type + '_noise_' + str(noise_level) \
                + '_k_cc_' + str(k_cc) + '_k_outlier_' + str(k_outlier) + '_start_' + str(when_to_denoise) \
                + '_every_' + str(denoise_every_n_epoch) + '.txt'
    log_dir = check_folder('logs/logs_txt_' + str(random_seed))
    file_name = os.path.join(log_dir, file_name)
    saver = open(file_name, "w")

    saver.write(
        'noise type: {}\nnoise level: {}\nk_cc: {}\nk_outlier: {}\nwhen_to_apply_epoch: {}\n'
        .format(noise_type, noise_level, k_cc, k_outlier, when_to_denoise))
    if noise_type != 'none':
        saver.write('total clean data num: {}\n'.format(len(keep_indices)))
        saver.write('probability transition matrix:\n{}\n'.format(p))
    saver.flush()

    # -- set network, optimizer, scheduler, etc --
    if which_net == 'resnet18':
        net = resnet18(in_channel=in_channel, num_classes=num_class)
        feature_size = 512
    elif which_net == 'resnet34':
        net = resnet34(in_channel=in_channel, num_classes=num_class)
        feature_size = 512
    elif which_net == 'pc':
        net = PointNetCls(k=num_class)
        feature_size = 256
    else:
        raise ValueError('Invalid network!')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = net.to(device)

    ################################################
    exp_lr_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                milestones=milestone,
                                                gamma=gamma)

    criterion = nn.NLLLoss()  # since the output of network is by log softmax

    # -- misc --
    best_acc = 0
    best_epoch = 0
    best_weights = None

    curr_trainloader = trainloader

    big_comp = set()

    patience = args.patience
    no_improve_counter = 0

    # -- start training --
    for epoch in range(start_epoch, num_epoch):
        train_correct = 0
        train_loss = 0
        train_total = 0

        exp_lr_scheduler.step()
        net.train()
        print("current train data size:", len(curr_trainloader.dataset))

        for _, (images, labels, _) in enumerate(curr_trainloader):
            if images.size(
                    0
            ) == 1:  # when batch size equals 1, skip, due to batch normalization
                continue
            images, labels = images.to(device), labels.to(device)

            outputs, features = net(images)
            log_outputs = torch.log_softmax(outputs, 1)

            loss = criterion(log_outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_total += images.size(0)
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()

        train_acc = train_correct / train_total * 100.

        cprint('epoch: {}'.format(epoch), 'white')
        cprint(
            'train accuracy: {}\ntrain loss:{}'.format(train_acc, train_loss),
            'yellow')

        # --- compute big connected components ---
        net.eval()
        features_all = torch.zeros(ntrain, feature_size).to(device)
        prob_all = torch.zeros(ntrain, num_class)

        labels_all = torch.zeros(ntrain, num_class)
        train_gt_labels = np.zeros(ntrain, dtype=np.uint64)
        train_pred_labels = np.zeros(ntrain, dtype=np.uint64)

        for _, (images, labels, indices) in enumerate(trainloader):
            images = images.to(device)

            outputs, features = net(images)

            softmax_outputs = torch.softmax(outputs, 1)

            features_all[indices] = features.detach()
            prob_all[indices] = softmax_outputs.detach().cpu()

            tmp_zeros = torch.zeros(labels.shape[0], num_class)
            tmp_zeros[torch.arange(labels.shape[0]), labels] = 1.0
            labels_all[indices] = tmp_zeros

            train_gt_labels[indices] = labels.cpu().numpy().astype(np.int64)
            train_pred_labels[indices] = labels.cpu().numpy().astype(np.int64)

        if epoch >= when_to_denoise and (
                epoch - when_to_denoise) % denoise_every_n_epoch == 0:
            cprint('\n>> Computing Big Components <<', 'white')

            labels_all = labels_all.numpy()
            train_gt_labels = train_gt_labels.tolist()
            train_pred_labels = np.squeeze(train_pred_labels).ravel().tolist()

            _, idx_of_comp_idx2 = calc_topo_weights_with_components_idx(
                ntrain,
                labels_all,
                features_all,
                train_gt_labels,
                train_pred_labels,
                k=k_cc,
                use_log=False,
                cp_opt=3,
                nclass=num_class)

            # --- update largest connected component ---
            cur_big_comp = list(set(range(ntrain)) - set(idx_of_comp_idx2))
            big_comp = big_comp.union(set(cur_big_comp))

            # --- remove outliers in largest connected component ---
            big_com_idx = list(big_comp)
            feats_big_comp = features_all[big_com_idx]
            labels_big_comp = np.array(train_gt_labels)[big_com_idx]

            knnG_list = calc_knn_graph(feats_big_comp, k=args.k_outlier)

            knnG_list = np.array(knnG_list)
            knnG_shape = knnG_list.shape
            knn_labels = labels_big_comp[knnG_list.ravel()]
            knn_labels = np.reshape(knn_labels, knnG_shape)

            majority, counts = mode(knn_labels, axis=-1)
            majority = majority.ravel()
            counts = counts.ravel()

            if args.zeta > 1.0:  # use majority vote
                non_outlier_idx = np.where(majority == labels_big_comp)[0]
                outlier_idx = np.where(majority != labels_big_comp)[0]
                outlier_idx = np.array(list(big_comp))[outlier_idx]
                print(">> majority == labels_big_comp -> size: ",
                      len(non_outlier_idx))

            else:  # zeta filtering
                non_outlier_idx = np.where((majority == labels_big_comp) & (
                    counts >= args.k_outlier * args.zeta))[0]
                print(">> zeta {}, then non_outlier_idx -> size: {}".format(
                    args.zeta, len(non_outlier_idx)))

                outlier_idx = np.where(majority != labels_big_comp)[0]
                outlier_idx = np.array(list(big_comp))[outlier_idx]

            cprint(">> The number of outliers: {}".format(len(outlier_idx)),
                   'red')
            cprint(
                ">> The purity of outliers: {}".format(
                    np.sum(y_train[outlier_idx] == noise_y_train[outlier_idx])
                    / float(len(outlier_idx))), 'red')

            big_comp = np.array(list(big_comp))[non_outlier_idx]
            big_comp = set(big_comp.tolist())

            # --- construct updated dataset set, which contains the collected clean data ---
            if which_data_set == 'cifar10':
                trainset_ignore_noisy_data = CIFAR10(
                    root='./data',
                    split='train',
                    train_ratio=train_val_ratio,
                    download=True,
                    transform=transform_train)
            elif which_data_set == 'cifar100':
                trainset_ignore_noisy_data = CIFAR100(
                    root='./data',
                    split='train',
                    train_ratio=train_val_ratio,
                    download=True,
                    transform=transform_train)
            else:
                trainset_ignore_noisy_data = ModelNet40(
                    split='train',
                    train_ratio=train_val_ratio,
                    num_ptrs=1024,
                    random_jitter=True)

            trainloader_ignore_noisy_data = torch.utils.data.DataLoader(
                trainset_ignore_noisy_data,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
                worker_init_fn=_init_fn,
                drop_last=True)
            curr_trainloader = trainloader_ignore_noisy_data

            noisy_data_indices = list(set(range(ntrain)) - big_comp)

            trainset_ignore_noisy_data.update_corrupted_label(noise_y_train)
            trainset_ignore_noisy_data.ignore_noise_data(noisy_data_indices)

            clean_data_num = len(big_comp.intersection(set(keep_indices)))
            noise_data_num = len(big_comp) - clean_data_num
            print("Big Comp Number:", len(big_comp))
            print("Found Noisy Data Number:", noise_data_num)
            print("Found True Data Number:", clean_data_num)

            # compute purity of the component
            cc_size = len(big_comp)
            equal = np.sum(
                noise_y_train[list(big_comp)] == y_train[list(big_comp)])
            ratio = equal / float(cc_size)
            print("Purity of current component: {}".format(ratio))

            noise_size = len(noisy_data_indices)
            equal = np.sum(noise_y_train[noisy_data_indices] ==
                           y_train[noisy_data_indices])
            print("Purity of data outside component: {}".format(
                equal / float(noise_size)))

            saver.write('Purity {}\tsize {}\t'.format(ratio, cc_size))

        # --- validation ---
        val_total = 0
        val_correct = 0
        net.eval()
        with torch.no_grad():
            for _, (images, labels, _) in enumerate(valloader):
                images, labels = images.to(device), labels.to(device)

                outputs, _ = net(images)

                val_total += images.size(0)
                _, predicted = outputs.max(1)
                val_correct += predicted.eq(labels).sum().item()

        val_acc = val_correct / val_total * 100.

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch
            best_weights = copy.deepcopy(net.state_dict())
            no_improve_counter = 0
        else:
            no_improve_counter += 1
            if no_improve_counter >= patience:
                print(
                    '>> No improvement for {} epochs. Stop at epoch {}'.format(
                        patience, epoch))
                saver.write(
                    '>> No improvement for {} epochs. Stop at epoch {}'.format(
                        patience, epoch))
                saver.write(
                    '>> val epoch: {}\n>> current accuracy: {}\n'.format(
                        epoch, val_acc))
                saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(
                    best_acc, best_epoch))
                break

        cprint('val accuracy: {}'.format(val_acc), 'cyan')
        cprint(
            '>> best accuracy: {}\n>> best epoch: {}\n'.format(
                best_acc, best_epoch), 'green')
        saver.write('{}\n'.format(val_acc))

    # -- testing
    cprint('>> testing <<', 'cyan')
    test_total = 0
    test_correct = 0

    net.load_state_dict(best_weights)
    net.eval()
    with torch.no_grad():
        for _, (images, labels, _) in enumerate(testloader):
            images, labels = images.to(device), labels.to(device)

            outputs, _ = net(images)

            test_total += images.size(0)
            _, predicted = outputs.max(1)
            test_correct += predicted.eq(labels).sum().item()

    test_acc = test_correct / test_total * 100.

    cprint('>> test accuracy: {}'.format(test_acc), 'cyan')
    saver.write('>> test accuracy: {}\n'.format(test_acc))

    # retest on the validation set, for sanity check
    cprint('>> validation <<', 'cyan')
    val_total = 0
    val_correct = 0
    net.eval()
    with torch.no_grad():
        for _, (images, labels, _) in enumerate(valloader):
            images, labels = images.to(device), labels.to(device)

            outputs, _ = net(images)

            val_total += images.size(0)
            _, predicted = outputs.max(1)
            val_correct += predicted.eq(labels).sum().item()

    val_acc = val_correct / val_total * 100.
    cprint('>> validation accuracy: {}'.format(val_acc), 'cyan')
    saver.write('>> validation accuracy: {}'.format(val_acc))

    saver.close()

    return test_acc
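The outlier-removal step above boils down to a kNN majority vote over the features of the retained component. Below is a stripped-down, self-contained sketch of that vote; calc_knn_graph is project-specific, so scikit-learn's NearestNeighbors stands in here, which is an assumption.

import numpy as np
from scipy.stats import mode
from sklearn.neighbors import NearestNeighbors

def knn_majority_filter(feats, labels, k=10, zeta=0.5):
    # indices of the k nearest neighbours of every point (excluding the point itself)
    nn_idx = NearestNeighbors(n_neighbors=k + 1).fit(feats).kneighbors(feats, return_distance=False)
    knn_labels = labels[nn_idx[:, 1:]]                 # (n, k) neighbour labels
    majority, counts = mode(knn_labels, axis=-1)
    majority, counts = majority.ravel(), counts.ravel()
    # keep points whose own label wins the vote with at least k * zeta agreeing neighbours
    keep = (majority == labels) & (counts >= k * zeta)
    return np.where(keep)[0]

feats = np.random.randn(200, 8)
labels = np.random.randint(0, 5, size=200)
print(len(knn_majority_filter(feats, labels, k=10, zeta=0.5)))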
Example #7
    batch_size = batch_size if len(
        params.gpus) == 0 else batch_size * len(params.gpus)

    train_dataloader = DataLoader(train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    print('train dataset len: {}'.format(len(train_dataloader.dataset)))

    val_dataloader = DataLoader(val_data,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)
    print('val dataset len: {}'.format(len(val_dataloader.dataset)))

    # models
    model = resnet18(pretrained=False, modelpath=model_path,
                     num_classes=1000)  # batch_size=120, 1GPU Memory < 7000M
    model.fc = nn.Linear(512, formation_num)
    # model = resnet101(pretrained=False, modelpath=model_path, num_classes=1000)  # batch_size=60, 1GPU Memory > 9000M
    # model.fc = nn.Linear(512*4, 6)

    # optimizer
    trainable_vars = [
        param for param in model.parameters() if param.requires_grad
    ]
    print("Training with sgd")
    params.optimizer = torch.optim.SGD(trainable_vars,
                                       lr=init_lr,
                                       momentum=momentum,
                                       weight_decay=weight_decay,
                                       nesterov=nesterov)
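The snippet scales the batch size by the number of listed GPUs; a hedged sketch of how such a GPU list is typically consumed afterwards (the trainer class itself is not shown, so this is an assumption): move the model to the first device and wrap it in nn.DataParallel when more than one GPU is configured.

import torch.nn as nn

def place_model(model, gpus):
    if len(gpus) == 0:          # empty list: stay on CPU
        return model
    model = model.cuda(gpus[0])
    if len(gpus) > 1:           # replicate each batch across all listed devices
        model = nn.DataParallel(model, device_ids=gpus)
    return model

# model = place_model(model, params.gpus)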
Example #8
def testBlue(total_data, visualize=True):
    dataset = [
        total_data[i]
        for i in np.random.choice(range(100000), 10, replace=False)
    ] if visualize else total_data
    images = []
    summary = [[] for i in range(formation_num)]

    model = resnet18(pretrained=False, num_classes=1000)
    model.fc = nn.Linear(512, 1)
    model.eval()
    m = torch.load("./models/state_evaluation.pth")
    new_dict = model.state_dict().copy()

    for i in range(len(m)):
        new_dict[list(model.state_dict().keys())[i]] = m[list(m.keys())[i]]

    model.load_state_dict(new_dict)

    city_position = [[255] * city_grid[1] for i in range(city_grid[0])]
    for pos in CITY_POSITION:
        city_position[pos[1]][pos[0] - 1 - ocean_grid[1]] = 0

    for data in dataset:
        base_position = data[0]

        row, col = base_position.shape

        for i in range(row):
            for j in range(col):
                if base_position[i][j]:
                    assert (base_position[i][j] <= MAX_MISSILE)
                    base_position[i][j] = (MAX_MISSILE - base_position[i][j]
                                           ) * int(255 / MAX_MISSILE)
                else:
                    base_position[i][j] = 255

        img = np.hstack(
            (base_position, 255 - 255 * data[1], np.array(city_position)))

        h, w = img.shape
        if (224 - h) % 2 != 0 or (224 - w) % 2 != 0:
            padding = (int((224 - w) / 2), 224 - int(
                (224 - w) / 2) - w, int((224 - h) / 2), 224 - int(
                    (224 - h) / 2) - h)
        else:
            padding = (int((224 - w) / 2), int((224 - h) / 2))

        transforms1 = T.Compose([T.Pad(padding, fill=255), T.ToTensor()])

        transforms2 = T.Compose([
            T.Pad((0, 29), fill=255),
        ])

        img = Image.fromarray(img.astype('uint8')).convert('L')
        eval_value = model(Variable(torch.unsqueeze(transforms1(img), 0)))
        eval_value = round(float(eval_value), 4)

        if visualize:
            image = transforms2(img)
            images.append([image, str(eval_value) + "/" + str(data[2])])
        else:
            # print(data[1], type(data[1]))
            forms = [f.tolist() for f in list(formations.values())]
            i = forms.index(data[1].tolist())
            summary[i].append(abs(eval_value - float(data[2])))

    # print(images)

    if visualize:
        plt.figure("State Evaluation")
        for i in range(1, len(images) + 1):
            plt.subplot(2, 5, i)
            plt.title(images[i - 1][1])
            plt.imshow(images[i - 1][0])
            # plt.axis('off')
        plt.show()

        if input() == " ":
            testBlue(total_data, visualize)
    else:
        avg, min_val, max_val = [], [], []
        total = 0
        for diff in summary:
            total += sum(diff)
            avg.append(round(sum(diff) / len(diff), 4))
            max_val.append(round(max(diff), 4))
            min_val.append(round(min(diff), 4))
        print(round(total / 10000, 4))
        index = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
        # newer matplotlib names the first argument of bar() 'x' rather than 'left',
        # so pass the positions positionally
        plt.bar(index, min_val, label="Min Absolute Difference")
        plt.bar(index, avg, bottom=min_val, label="Avg Absolute Difference")
        plt.bar(index, max_val, bottom=avg, label="Max Absolute Difference")
        plt.xlabel('Formation')
        plt.ylabel('Absolute Difference Between Evaluation and Ground Truth')
        plt.legend(loc='best')
        plt.show()
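The padding arithmetic above centers each map inside a 224x224 canvas before it is fed to the ResNet. A small hedged helper with the same intent is shown below; it uses torchvision's documented 4-tuple order for T.Pad (left, top, right, bottom), while the ordering used in the snippet is not spelled out, so treat this as a reference sketch rather than the original logic.

import torchvision.transforms as T
from PIL import Image

def pad_to_square(img, size=224, fill=255):
    w, h = img.size
    left, top = (size - w) // 2, (size - h) // 2
    right, bottom = size - w - left, size - h - top
    return T.Pad((left, top, right, bottom), fill=fill)(img)

print(pad_to_square(Image.new('L', (99, 55), 0)).size)  # (224, 224)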
Example #9
def create_model(Model, num_classes):
    if Model == "ResNet18":
        model = resnet18(pretrained=False)
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()

    elif Model == "ResNet34":
        model = resnet34(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()

    elif Model == "ResNet50":
        model = resnet50(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()

    elif Model == "ResNet101":
        model = resnet101(pretrained=False).cuda()
        fc_features = model.fc.in_features
        model.fc = nn.Linear(fc_features, num_classes)
        model = model.cuda()

    elif Model == "MobileNet_v2":
        model = mobilenet_v2(num_classes=num_classes, pretrained=False).cuda()

    elif Model == "Mobilenetv3":
        model = mobilenetv3(n_class=num_classes, pretrained=False).cuda()

    elif Model == "shufflenet_v2_x0_5":
        model = shufflenet_v2_x0_5(pretrained=False,
                                   num_classes=num_classes).cuda()

    elif Model == "shufflenet_v2_x1_0":
        model = shufflenet_v2_x1_0(pretrained=False,
                                   num_classes=num_classes).cuda()

    elif Model == "shufflenet_v2_x1_5":
        model = shufflenet_v2_x1_5(pretrained=False,
                                   num_classes=num_classes).cuda()

    elif Model == "shufflenet_v2_x1_5":
        model = shufflenet_v2_x2_0(pretrained=False,
                                   num_classes=num_classes).cuda()

    elif "efficientnet" in Model:
        model = EfficientNet.from_pretrained(Model,
                                             num_classes=num_classes,
                                             pretrained=False).cuda()

    elif Model == "inception_v3":
        model = inception_v3(pretrained=False, num_classes=num_classes).cuda()

    elif Model == "mnasnet0_5":
        model = mnasnet0_5(pretrained=False, num_classes=num_classes).cuda()

    elif Model == "vgg11":
        model = vgg11(pretrained=False, num_classes=num_classes).cuda()

    elif Model == "vgg11_bn":
        model = vgg11_bn(pretrained=False, num_classes=num_classes).cuda()

    elif Model == "vgg19":
        model = vgg19(pretrained=False, num_classes=num_classes).cuda()

    elif Model == "densenet121":
        model = densenet121(pretrained=False).cuda()

    elif Model == "ResNeXt29_32x4d":
        model = ResNeXt29_32x4d(num_classes=num_classes).cuda()

    elif Model == "ResNeXt29_2x64d":
        model = ResNeXt29_2x64d(num_classes=num_classes).cuda()

    else:
        raise ValueError("Unsupported model: {}".format(Model))

    # input = torch.randn(1, 3, 32, 32).cuda()
    # flops, params = profile(model, inputs=(input,))
    # flops, params = clever_format([flops, params], "%.3f")
    # print("------------------------------------------------------------------------")
    # print("                              ",Model)
    # print( "                    flops:", flops, "    params:", params)
    # print("------------------------------------------------------------------------")

    return model
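A short usage sketch of the factory above; a CUDA device is assumed since every branch calls .cuda(), and the 32x32 input matches the commented-out profiling code.

if __name__ == "__main__":
    import torch

    model = create_model("ResNet18", num_classes=10)
    logits = model(torch.randn(4, 3, 32, 32).cuda())
    print(logits.shape)  # torch.Size([4, 10])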
Example #10
def main(args):

    random_seed = int(np.random.choice(range(1000), 1))
    np.random.seed(random_seed)
    random.seed(random_seed)
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True

    arg_which_net = args.network
    arg_dataset = args.dataset
    arg_epoch_start = args.epoch_start
    lr = args.lr
    arg_gpu = args.gpu
    arg_num_gpu = args.n_gpus
    arg_every_n_epoch = args.every_n_epoch   # interval to perform the correction
    arg_epoch_update = args.epoch_update     # the epoch to start correction (warm-up period)
    arg_epoch_interval = args.epoch_interval # interval between two update of A
    noise_level = args.noise_level
    noise_type = args.noise_type             # "uniform", "asymmetric", "none"
    train_val_ratio = 0.9
    which_net = arg_which_net                # "cnn" "resnet18" "resnet34" "preact_resnet18" "preact_resnet34" "preact_resnet101" "pc"
    num_epoch = args.n_epochs                # Total training epochs


    print('Using {}\nTest on {}\nRandom Seed {}\nevery n epoch {}\nStart at epoch {}'.
          format(arg_which_net, arg_dataset, random_seed, arg_every_n_epoch, arg_epoch_start))

    # -- training parameters
    if arg_dataset == 'mnist':
        milestone = [30, 60]
        batch_size = 64
        in_channels = 1
    elif arg_dataset == 'cifar10':
        milestone = [60, 180]
        batch_size = 128
        in_channels = 3
    elif arg_dataset == 'cifar100':
        milestone = [60, 180]
        batch_size = 128
        in_channels = 3
    elif arg_dataset == 'pc':
        milestone = [30, 60]
        batch_size = 128

    start_epoch = 0
    num_workers = 1

    #gamma = 0.5

    # -- specify dataset
    # data augmentation
    if arg_dataset == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
    elif arg_dataset == 'cifar100':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),  # CIFAR-100 per-channel mean / std
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.507, 0.487, 0.441), (0.267, 0.256, 0.276)),  # CIFAR-100 per-channel mean / std
        ])
    elif arg_dataset == 'mnist':
        transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
    else:
        transform_train = None
        transform_test = None

    if arg_dataset == 'cifar10':
        trainset = CIFAR10(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn)
        valset = CIFAR10(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test)
        valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        testset = CIFAR10(root='./data', split='test', download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        images_size = [3, 32, 32]
        num_class = 10
        in_channel = 3

    elif arg_dataset == 'cifar100':
        trainset = CIFAR100(root='./data', split='train', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn)
        valset = CIFAR100(root='./data', split='val', train_ratio=train_val_ratio, trust_ratio=0, download=True, transform=transform_test)
        valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        testset = CIFAR100(root='./data', split='test', download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        num_class = 100
        in_channel = 3

    elif arg_dataset == 'mnist':
        trainset = MNIST(root='./data', split='train', train_ratio=train_val_ratio, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers,worker_init_fn=_init_fn)
        valset = MNIST(root='./data', split='val', train_ratio=train_val_ratio, download=True, transform=transform_test)
        valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        testset = MNIST(root='./data', split='test', download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        num_class = 10
        in_channel = 1

    elif arg_dataset == 'pc':
        trainset = ModelNet40(split='train', train_ratio=train_val_ratio, num_ptrs=1024, random_jitter=True)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=_init_fn, drop_last=True)
        valset = ModelNet40(split='val', train_ratio=train_val_ratio, num_ptrs=1024)
        valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        testset = ModelNet40(split='test', num_ptrs=1024)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        num_class = 40

    print('train data size:', len(trainset))
    print('validation data size:', len(valset))
    print('test data size:', len(testset))

    eps = 1e-6               # epsilon used to soften the labels (this is not the epsilon in the paper)
    ntrain = len(trainset)

    # -- generate noise --
    # y_train holds the ground-truth labels; once the noisy labels are generated the
    # algorithm must not use it (after y_tilde is created it is kept only for monitoring)
    y_train = trainset.get_data_labels()
    y_train = np.array(y_train)

    noise_y_train = None
    keep_indices = None
    p = None

    if noise_type == 'none':
        pass
    else:
        if noise_type == "uniform":
            noise_y_train, p, keep_indices = noisify_with_P(y_train, nb_classes=num_class, noise=noise_level, random_state=random_seed)
            trainset.update_corrupted_label(noise_y_train)
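            # Build eps-smoothed one-hot soft labels from the (noisy) hard labels: every
            # wrong class gets eps / (num_class - 1) and the labelled class gets 1 - eps,
            # e.g. with num_class = 3, label 1 becomes [eps/2, 1 - eps, eps/2].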
            noise_softlabel = torch.ones(ntrain, num_class)*eps/(num_class-1)
            noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1-eps)
            trainset.update_corrupted_softlabel(noise_softlabel)

            print("apply uniform noise")
        else:
            if arg_dataset == 'cifar10':
                noise_y_train, p, keep_indices = noisify_cifar10_asymmetric(y_train, noise=noise_level, random_state=random_seed)
            elif arg_dataset == 'cifar100':
                noise_y_train, p, keep_indices = noisify_cifar100_asymmetric(y_train, noise=noise_level, random_state=random_seed)
            elif arg_dataset == 'mnist':
                noise_y_train, p, keep_indices = noisify_mnist_asymmetric(y_train, noise=noise_level, random_state=random_seed)
            elif arg_dataset == 'pc':
                noise_y_train, p, keep_indices = noisify_modelnet40_asymmetric(y_train, noise=noise_level,
                                                                               random_state=random_seed)
            trainset.update_corrupted_label(noise_y_train)
            noise_softlabel = torch.ones(ntrain, num_class) * eps / (num_class - 1)
            noise_softlabel.scatter_(1, torch.tensor(noise_y_train.reshape(-1, 1)), 1 - eps)
            trainset.update_corrupted_softlabel(noise_softlabel)

            print("apply asymmetric noise")
        print("clean data num:", len(keep_indices))
        print("probability transition matrix:\n{}".format(p))

    # -- create log file
    file_name = '[' + arg_dataset + '_' + which_net + ']' \
                + 'type:' + noise_type + '_' + 'noise:' + str(noise_level) + '_' \
                + '_' + 'start:' + str(arg_epoch_start) + '_' \
                + 'every:' + str(arg_every_n_epoch) + '_'\
                + 'time:' + str(datetime.datetime.now()) + '.txt'
    log_dir = check_folder('new_logs/logs_txt_' + str(random_seed))
    file_name = os.path.join(log_dir, file_name)
    saver = open(file_name, "w")

    saver.write('noise type: {}\nnoise level: {}\nwhen_to_apply_epoch: {}\n'.format(
        noise_type, noise_level, arg_epoch_start))

    if noise_type != 'none':
        saver.write('total clean data num: {}\n'.format(len(keep_indices)))
        saver.write('probability transition matrix:\n{}\n'.format(p))
    saver.flush()

    # -- set network, optimizer, scheduler, etc
    if which_net == "cnn":
        net_trust = CNN9LAYER(input_channel=in_channel, n_outputs=num_class)
        net = CNN9LAYER(input_channel=in_channel, n_outputs=num_class)
        net.apply(weight_init)
        feature_size = 128
    elif which_net == 'resnet18':
        net_trust = resnet18(in_channel=in_channel, num_classes=num_class)
        net = resnet18(in_channel=in_channel, num_classes=num_class)
        feature_size = 512
    elif which_net == 'resnet34':
        net_trust = resnet34(in_channel=in_channel, num_classes=num_class)
        net = resnet34(in_channel=in_channel, num_classes=num_class)
        feature_size = 512
    elif which_net == 'preact_resnet18':
        net_trust = preact_resnet18(num_classes=num_class, num_input_channels=in_channel)
        net = preact_resnet18(num_classes=num_class, num_input_channels=in_channel)
        feature_size = 256
    elif which_net == 'preact_resnet34':
        net_trust = preact_resnet34(num_classes=num_class, num_input_channels=in_channel)
        net = preact_resnet34(num_classes=num_class, num_input_channels=in_channel)
        feature_size = 256
    elif which_net == 'preact_resnet101':
        net_trust = preact_resnet101()
        net = preact_resnet101()
        feature_size = 256
    elif which_net == 'pc':
        net_trust = PointNetCls(k=num_class)
        net = PointNetCls(k=num_class)
        feature_size = 256
    else:
        raise ValueError('Invalid network!')

    opt_gpus = [i for i in range(arg_gpu, arg_gpu+int(arg_num_gpu))]
    if len(opt_gpus) > 1:
        print("Using ", len(opt_gpus), " GPUs")
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(x) for x in opt_gpus)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if len(opt_gpus) > 1:
        net_trust = torch.nn.DataParallel(net_trust)
        net = torch.nn.DataParallel(net)
    net_trust.to(device)
    net.to(device)
    # net.apply(conv_init)

    # ------------------------- Initialize Setting ------------------------------
    best_acc = 0
    best_epoch = 0
    patience = 50
    no_improve_counter = 0

    A = 1/num_class*torch.ones(ntrain, num_class, num_class, requires_grad=False).float().to(device)
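    # (A above stores one num_class x num_class matrix per training sample, initialized
    # uniform; it is refined later by updateA() and enters the retro loss term once
    # epoch >= arg_epoch_start.)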
    h = np.zeros([ntrain, num_class])

    criterion_1 = nn.NLLLoss()
    pred_softlabels = np.zeros([ntrain, arg_every_n_epoch, num_class], dtype=float)  # np.float was removed in NumPy 1.24

    train_acc_record = []
    clean_train_acc_record = []
    noise_train_acc_record = []
    val_acc_record = []
    recovery_record = []
    noise_ytrain = copy.copy(noise_y_train)
    #noise_ytrain = torch.tensor(noise_ytrain).to(device)

    cprint("================  Clean Label...  ================", "yellow")
    for epoch in range(num_epoch):  # Add some modification here

        train_correct = 0
        train_loss = 0
        train_total = 0
        delta = 1.2 + 0.02*max(epoch - arg_epoch_update + 1, 0)

        clean_train_correct = 0
        noise_train_correct = 0

        optimizer_trust = RAdam(net_trust.parameters(),
                                     lr=learning_rate(lr, epoch),
                                     weight_decay=5e-4)
        #optimizer_trust = optim.SGD(net_trust.parameters(), lr=learning_rate(lr, epoch), weight_decay=5e-4,
        #                            nesterov=True, momentum=0.9)

        net_trust.train()

        # Train with noisy data
        for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)):
            if images.size(0) == 1:  # when batch size equals 1, skip, due to batch normalization
                continue

            images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device)
            outputs, features = net_trust(images)
            log_outputs = torch.log_softmax(outputs, 1).float()

            # arg_epoch_start : epoch start to introduce loss retro
            # arg_epoch_interval : epochs between two updating of A
            if epoch in [arg_epoch_start-1, arg_epoch_start+arg_epoch_interval-1]:
                h[indices] = log_outputs.detach().cpu()
            normal_outputs = torch.softmax(outputs, 1)

            if epoch >= arg_epoch_start: # use loss_retro + loss_ce
                A_batch = A[indices].to(device)
                loss = sum([-A_batch[i].matmul(softlabels[i].reshape(-1, 1).float()).t().matmul(log_outputs[i])
                            for i in range(len(indices))]) / len(indices) + \
                       criterion_1(log_outputs, labels)
            else: # use loss_ce
                loss = criterion_1(log_outputs, labels)

            optimizer_trust.zero_grad()
            loss.backward()
            optimizer_trust.step()

            #arg_every_n_epoch : rolling windows to get eta_tilde
            if epoch >= (arg_epoch_update - arg_every_n_epoch):
                pred_softlabels[indices, epoch % arg_every_n_epoch, :] = normal_outputs.detach().cpu().numpy()

            train_loss += loss.item()
            train_total += images.size(0)
            _, predicted = outputs.max(1)
            train_correct += predicted.eq(labels).sum().item()

            # For monitoring purpose, comment the following line out if you have your own dataset that doesn't have ground truth label
            train_label_clean = torch.tensor(y_train)[indices].to(device)
            train_label_noise = torch.tensor(noise_ytrain[indices]).to(device)
            clean_train_correct += predicted.eq(train_label_clean).sum().item()
            noise_train_correct += predicted.eq(train_label_noise).sum().item() # acc wrt the original noisy labels

        train_acc = train_correct / train_total * 100
        clean_train_acc = clean_train_correct/train_total*100
        noise_train_acc = noise_train_correct/train_total*100
        print(" Train Epoch: [{}/{}] \t Training Acc wrt Corrected {:.3f} \t Train Acc wrt True {:.3f} \t Train Acc wrt Noise {:.3f}".
              format(epoch, num_epoch, train_acc, clean_train_acc, noise_train_acc))

        # updating A
        if epoch in [arg_epoch_start - 1, arg_epoch_start + 40 - 1]:
            cprint("+++++++++++++++++ Updating A +++++++++++++++++++", "magenta")
            unsolved = 0
            infeasible = 0
            y_soft = trainset.get_data_softlabel()

            with torch.no_grad():
                for i in tqdm(range(ntrain), ncols=100, ascii=True):
                    try:
                        result, A_opt = updateA(y_soft[i], h[i], rho=0.9)
                    except Exception:
                        unsolved += 1       # solver failed, keep the previous A[i]
                        continue

                    if result == np.inf:
                        infeasible += 1     # infeasible solution, keep the previous A[i]
                    else:
                        A[i] = torch.tensor(A_opt)
            print(A[0])
            print("Unsolved points: {} | Infeasible points: {}".format(unsolved, infeasible))

        # applying lRT scheme
        # args_epoch_update : epoch to update labels
        if epoch >= arg_epoch_update:
            y_tilde = trainset.get_data_labels()
            pred_softlabels_bar = pred_softlabels.mean(1)
            clean_labels, clean_softlabels = lrt_flip_scheme(pred_softlabels_bar, y_tilde, delta)
            trainset.update_corrupted_softlabel(clean_softlabels)
            trainset.update_corrupted_label(clean_softlabels.argmax(1))

        # validation
        if not (epoch % 5):
            val_total = 0
            val_correct = 0
            net_trust.eval()

            with torch.no_grad():
                for i, (images, labels, _, _) in enumerate(valloader):
                    images, labels = images.to(device), labels.to(device)

                    outputs, _ = net_trust(images)

                    val_total += images.size(0)
                    _, predicted = outputs.max(1)
                    val_correct += predicted.eq(labels).sum().item()

            val_acc = val_correct / val_total * 100

            train_acc_record.append(train_acc)
            val_acc_record.append(val_acc)
            clean_train_acc_record.append(clean_train_acc)
            noise_train_acc_record.append(noise_train_acc)

            recovery_acc = np.sum(trainset.get_data_labels() == y_train) / ntrain
            recovery_record.append(recovery_acc)

            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                no_improve_counter = 0
            else:
                no_improve_counter += 1
                if no_improve_counter >= patience:
                    print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch))
                    saver.write('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch))
                    saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc))
                    saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch))
                    break

            cprint('val accuracy: {}'.format(val_acc), 'cyan')
            cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green')
            cprint('>> final recovery rate: {}\n'.format(recovery_acc),
                   'green')
            saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc))
            saver.write("outputs: {}\n".format(normal_outputs))
            saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch))
            saver.write(
                '>> final recovery rate: {}%\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain * 100))
            saver.flush()

    # To retrain the network on the corrected labels with the plain cross-entropy loss,
    # set args.two_stage to True.
    print("Use Two-Stage Model {}".format(args.two_stage))
    if args.two_stage:
       criterion_2 = nn.NLLLoss()
       best_acc = 0
       best_epoch = 0
       patience = 50
       no_improve_counter = 0

       cprint("================ Normal Training  ================", "yellow")
       for epoch in range(num_epoch):  # Add some modification here

           train_correct = 0
           train_loss = 0
           train_total = 0

           optimizer_trust = optim.SGD(net.parameters(), momentum=0.9, nesterov=True,
                                       lr=learning_rate(lr, epoch),
                                       weight_decay=5e-4)

           net.train()

           for i, (images, labels, softlabels, indices) in enumerate(tqdm(trainloader, ncols=100, ascii=True)):
               if images.size(0) == 1:  # when batch size equals 1, skip, due to batch normalization
                   continue

               images, labels, softlabels = images.to(device), labels.to(device), softlabels.to(device)
               outputs, features = net(images)
               log_outputs = torch.log_softmax(outputs, 1).float()

               loss = criterion_2(log_outputs, labels)

               optimizer_trust.zero_grad()
               loss.backward()
               optimizer_trust.step()

               train_loss += loss.item()
               train_total += images.size(0)
               _, predicted = outputs.max(1)
               train_correct += predicted.eq(labels).sum().item()

           train_acc = train_correct / train_total * 100
           print(" Train Epoch: [{}/{}] \t Training Accuracy {}%".format(epoch, num_epoch, train_acc))

           if not (epoch % 5):

               val_total = 0
               val_correct = 0
               net.eval()  # evaluate the network being trained in this stage

               with torch.no_grad():
                   for i, (images, labels, _, _) in enumerate(valloader):
                       images, labels = images.to(device), labels.to(device)
                       outputs, _ = net(images)

                       # accumulate statistics for every validation batch
                       val_total += images.size(0)
                       _, predicted = outputs.max(1)
                       val_correct += predicted.eq(labels).sum().item()

               val_acc = val_correct / val_total * 100

               if val_acc > best_acc:
                   best_acc = val_acc
                   best_epoch = epoch
                   no_improve_counter = 0
               else:
                   no_improve_counter += 1
                   if no_improve_counter >= patience:
                       print('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch))
                       saver.write('>> No improvement for {} epochs. Stop at epoch {}'.format(patience, epoch))
                       saver.write('>> val epoch: {}\n>> current accuracy: {}%\n'.format(epoch, val_acc))
                       saver.write('>> best accuracy: {}\tbest epoch: {}\n\n'.format(best_acc, best_epoch))
                       break

               cprint('val accuracy: {}'.format(val_acc), 'cyan')
               cprint('>> best accuracy: {}\n>> best epoch: {}\n'.format(best_acc, best_epoch), 'green')
               cprint('>> final recovery rate: {}\n'.format(np.sum(trainset.get_data_labels() == y_train) / ntrain), 'green')

    cprint("================  Start Testing  ================", "yellow")
    test_total = 0
    test_correct = 0

    net_trust.eval()
    with torch.no_grad():
        for i, (images, labels, softlabels, indices) in enumerate(testloader):
            if images.shape[0] == 1:
                continue

            images, labels = images.to(device), labels.to(device)
            outputs, _ = net_trust(images)

            test_total += images.shape[0]
            test_correct += outputs.argmax(1).eq(labels).sum().item()

    test_acc = test_correct / test_total * 100
    print("Final test accuracy {} %".format(test_acc))

    return test_acc
Example #11
    def __init__(self,
                 ver_dim,
                 seg_dim,
                 fcdim=256,
                 s8dim=128,
                 s4dim=64,
                 s2dim=32,
                 raw_dim=32,
                 inp_dim=3):
        super(Resnet18_8s, self).__init__()

        # Load the pretrained weights, remove avg pool
        # layer and get the output stride of 8
        resnet18_8s = resnet18(
            inp_dim=inp_dim,
            fully_conv=True,
            pretrained=True,
            output_stride=8,
            remove_avg_pool_layer=True,
        )

        self.ver_dim = ver_dim
        self.seg_dim = seg_dim

        # Replace the classifier with a randomly initialized 3x3 conv scoring head
        resnet18_8s.fc = nn.Sequential(
            nn.Conv2d(resnet18_8s.inplanes, fcdim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(fcdim),
            nn.ReLU(True),
        )
        self.resnet18_8s = resnet18_8s

        # x8s->128
        self.conv8s = nn.Sequential(
            nn.Conv2d(128 + fcdim, s8dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s8dim),
            nn.LeakyReLU(0.1, True),
        )
        self.up8sto4s = nn.UpsamplingBilinear2d(scale_factor=2)

        # x4s->64
        self.conv4s = nn.Sequential(
            nn.Conv2d(64 + s8dim, s4dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s4dim),
            nn.LeakyReLU(0.1, True),
        )
        self.up4sto2s = nn.UpsamplingBilinear2d(scale_factor=2)

        # x2s->32
        self.conv2s = nn.Sequential(
            nn.Conv2d(64 + s4dim, s2dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(s2dim),
            nn.LeakyReLU(0.1, True),
        )
        self.up2storaw = nn.UpsamplingBilinear2d(scale_factor=2)

        self.convraw = nn.Sequential(
            nn.Conv2d(inp_dim + s2dim, raw_dim, 3, 1, 1, bias=False),
            nn.BatchNorm2d(raw_dim),
            nn.LeakyReLU(0.1, True),
            nn.Conv2d(raw_dim, self.seg_dim + self.ver_dim, 1, 1),
        )
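
The constructor above only builds the decoder; below is a minimal sketch of the forward pass it implies, assuming the modified resnet18 returns its stride-2/4/8 feature maps (x2s, x4s, x8s) together with the fc-head output — that backbone interface and the variable names are assumptions, not part of the original snippet.

    def forward(self, x):
        # assumed backbone interface: intermediate feature maps plus the fc output
        x2s, x4s, x8s, xfc = self.resnet18_8s(x)

        fm = self.conv8s(torch.cat([xfc, x8s], 1))   # fuse fc features with stride-8 features
        fm = self.up8sto4s(fm)
        fm = self.conv4s(torch.cat([fm, x4s], 1))    # fuse with stride-4 features
        fm = self.up4sto2s(fm)
        fm = self.conv2s(torch.cat([fm, x2s], 1))    # fuse with stride-2 features
        fm = self.up2storaw(fm)
        out = self.convraw(torch.cat([fm, x], 1))    # fuse with the raw input image

        seg_pred = out[:, :self.seg_dim]             # segmentation logits
        ver_pred = out[:, self.seg_dim:]             # vertex / vector-field prediction
        return seg_pred, ver_pred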
Example #12
    def __init__(self, mode):
        super(buildlModel, self).__init__()
        self.backbone = resnet18(pretrained=True)
        self.neck = AdjustLayer(256, 256)
        self.rpn = RPN(256, 256, 1)
        self.mode = mode
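
Only the constructor of this Siamese tracker is shown; here is a hedged sketch of a matching forward pass, assuming the backbone yields one feature map per image and that RPN consumes (template_features, search_features) — both are assumptions about this codebase rather than documented interfaces.

    def forward(self, template, search):
        # shared backbone + neck for the exemplar and the search region (illustrative only)
        zf = self.neck(self.backbone(template))   # template (exemplar) features
        xf = self.neck(self.backbone(search))     # search-region features
        cls, loc = self.rpn(zf, xf)               # classification and box-regression maps
        return cls, loc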
Example #13
    def __init__(self):

        self.img_size = constant.IMG_SIZE

        self.noise_dim = config.generator["noise_dim"]
        self.generator = nn.DataParallel(
            TP_GAN.Generator(
                noise_dim=config.generator["noise_dim"],
                encode_feature_dim=config.generator["encode_feature_dim"],
                encode_predict_dim=config.generator["encode_predict_dim"],
                use_batchnorm=config.generator["use_batchnorm"],
                use_residual_block=config.generator["use_residual_block"]))
        self.discriminator = nn.DataParallel(
            TP_GAN.Discriminator(config.discriminator["use_batchnorm"]))
        self.lr = config.settings["init_lr"]
        self.optimizer_G = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.lr)
        self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.lr)
        self.lr_scheduler_G = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_G, milestones=[], gamma=0.1)
        self.lr_scheduler_D = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer_D, milestones=[], gamma=0.1)
        self.epoch = config.settings["epoch"]
        self.batch_size = config.settings["batch_size"]
        self.train_dataloader = torch.utils.data.DataLoader(
            dataset.train_dataloader(config.path["img_test"]),
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            pin_memory=True)
        self.feature_extract_network_resnet18 = nn.DataParallel(
            resnet.resnet18(
                pretrained=True,
                num_classes=config.generator["encode_predict_dim"]))
        self.l1_loss = nn.L1Loss()
        self.MSE_loss = nn.MSELoss()
        self.cross_entropy = nn.CrossEntropyLoss()

        if USE_CUDA:
            self.generator = self.generator.cuda()
            self.discriminator = self.discriminator.cuda()
            self.l1_loss = self.l1_loss.cuda()
            self.MSE_loss = self.MSE_loss.cuda()
            self.cross_entropy = self.cross_entropy.cuda()

    def backward_D(self, img_predict, img_frontal):
        set_requires_grad(self.discriminator, True)
        # WGAN critic loss: E[D(fake)] - E[D(real)]
        adversarial_D_loss = -torch.mean(
            self.discriminator(img_frontal)) + torch.mean(
                self.discriminator(img_predict))
        # interpolate between real and generated images for the gradient penalty
        factor = torch.rand(img_frontal.shape[0], 1, 1,
                            1).expand(img_frontal.size())
        if USE_CUDA:
            factor = factor.cuda()
        interpolated_value = Variable(factor * img_predict.data +
                                      (1.0 - factor) * img_frontal.data,
                                      requires_grad=True)
        output = self.discriminator(interpolated_value)
        # WGAN-GP penalty: E[(||grad_x D(x_hat)||_2 - 1)^2] over the interpolates
        gradient = torch.autograd.grad(outputs=output,
                                       inputs=interpolated_value,
                                       grad_outputs=torch.ones_like(output),
                                       retain_graph=True,
                                       create_graph=True,
                                       only_inputs=True)[0]
        gradient = gradient.view(output.shape[0], -1)
        gp_loss = torch.mean((torch.norm(gradient, p=2, dim=1) - 1) ** 2)
        loss_D = adversarial_D_loss + config.loss[
            "weight_gradient_penalty"] * gp_loss

        self.optimizer_D.zero_grad()
        loss_D.backward()
        self.optimizer_D.step()

        return loss_D

    def backward_G(self, inputs, outputs):
        set_requires_grad(self.discriminator, False)
        # adversarial loss
        adversarial_G_loss = -torch.mean(
            self.discriminator(outputs["img128_predict"]))
        # pixel-wise loss at each output resolution
        pixelwise_loss_128 = self.l1_loss(inputs["img128_frontal"],
                                          outputs["img128_predict"])
        pixelwise_loss_64 = self.l1_loss(inputs["img64_frontal"],
                                         outputs["img64_predict"])
        pixelwise_loss_32 = self.l1_loss(inputs["img32_frontal"],
                                         outputs["img32_predict"])
        pixelwise_global_loss = config.loss["pixelwise_weight_128"] * pixelwise_loss_128 + \
            config.loss["pixelwise_weight_64"] * pixelwise_loss_64 + \
            config.loss["pixelwise_weight_32"] * pixelwise_loss_32

        left_eye_loss = self.l1_loss(inputs["left_eye_frontal"],
                                     outputs["left_eye_predict"])
        right_eye_loss = self.l1_loss(inputs["right_eye_frontal"],
                                      outputs["right_eye_predict"])
        nose_loss = self.l1_loss(inputs["nose_frontal"],
                                 outputs["nose_predict"])
        mouth_loss = self.l1_loss(inputs["mouth_frontal"],
                                  outputs["mouth_predict"])
        pixel_local_loss = left_eye_loss + right_eye_loss + nose_loss + mouth_loss
        # symmetry loss: the frontalized face should be left-right symmetric
        img128 = outputs["img128_predict"]
        img64 = outputs["img64_predict"]
        img32 = outputs["img32_predict"]
        inv_idx128 = torch.arange(img128.size()[3] - 1, -1, -1).long()
        inv_idx64 = torch.arange(img64.size()[3] - 1, -1, -1).long()
        inv_idx32 = torch.arange(img32.size()[3] - 1, -1, -1).long()
        if USE_CUDA:
            inv_idx128 = inv_idx128.cuda()
            inv_idx64 = inv_idx64.cuda()
            inv_idx32 = inv_idx32.cuda()
        img128_flip = img128.index_select(3, Variable(inv_idx128))
        img64_flip = img64.index_select(3, Variable(inv_idx64))
        img32_flip = img32.index_select(3, Variable(inv_idx32))
        img128_flip.detach_()
        img64_flip.detach_()
        img32_flip.detach_()

        symmetry_loss_128 = self.l1_loss(img128, img128_flip)
        symmetry_loss_64 = self.l1_loss(img64, img64_flip)
        symmetry_loss_32 = self.l1_loss(img32, img32_flip)
        symmetry_loss = config.loss["symmetry_weight_128"] * symmetry_loss_128 + \
            config.loss["symmetry_weight_64"] * symmetry_loss_64 + \
            config.loss["symmetry_weight_32"] * symmetry_loss_32

        # identity preserving loss
        feature_frontal = self.feature_extract_network_resnet18(
            inputs["img128_frontal"])
        feature_predict = self.feature_extract_network_resnet18(
            outputs["img128_predict"])
        identity_preserving_loss = self.MSE_loss(feature_frontal,
                                                 feature_predict)

        # total variation loss for regularization
        img128 = outputs["img128_predict"]
        total_variation_loss = torch.mean(
            torch.abs(img128[:, :, :-1, :] - img128[:, :, 1:, :])) + \
            torch.mean(torch.abs(img128[:, :, :, :-1] - img128[:, :, :, 1:]))
        # cross entropy loss
        cross_entropy_loss = self.cross_entropy(
            outputs["encode_predict"], inputs["label"])

        # synthesized loss
        synthesized_loss = config.loss["pixelwise_global_weight"] * pixelwise_global_loss + \
            config.loss["pixel_local_weight"] * pixel_local_loss + \
            config.loss["symmetry_weight"] * symmetry_loss + \
            config.loss["adversarial_G_weight"] * adversarial_G_loss + \
            config.loss["identity_preserving_weight"] * identity_preserving_loss + \
            config.loss["total_variation_weight"] * total_variation_loss
        loss_G = synthesized_loss + config.loss[
            "cross_entropy_weight"] * cross_entropy_loss

        self.optimizer_G.zero_grad()
        loss_G.backward()
        self.optimizer_G.step()

    def train(self):
        self.generator.train()
        self.discriminator.train()
        self.feature_extract_network_resnet18.eval()
        for cur_epoch in range(self.epoch):
            print("\ncurrent epoch number: %d" % cur_epoch)
            self.lr_scheduler_G.step()
            self.lr = self.lr_scheduler_G.get_lr()[0]
            self.lr_scheduler_D.step()
            self.lr = self.lr_scheduler_D.get_lr()[0]
            for batch_index, inputs in enumerate(self.train_dataloader):
                noise = Variable(
                    torch.FloatTensor(
                        np.random.uniform(
                            -1, 1, (self.batch_size, self.noise_dim))))
                # move the batch to the GPU and wrap it, writing back into the dict
                for k, v in inputs.items():
                    if USE_CUDA:
                        v = v.cuda()
                    inputs[k] = Variable(v, requires_grad=False)
                if USE_CUDA:
                    noise = noise.cuda()
                generator_output = self.generator(inputs["img"],
                                                  inputs["left_eye"],
                                                  inputs["right_eye"],
                                                  inputs["nose"],
                                                  inputs["mouth"], noise)
                # update the critic on detached generator output, then the generator
                self.backward_D(generator_output["img128_predict"].detach(),
                                inputs["img128_frontal"])
                self.backward_G(inputs, generator_output)
        self.save_model()

    def save_model(self):
        torch.save(self.generator.cpu().state_dict(),
                   config.path["generator__path"])
        torch.save(self.discriminator.cpu().state_dict(),
                   config.path["discriminator_save_path"])