Example #1
0
            # get inputs and labels from data
            inputs, _, img = data
            inputs = inputs.to(device)

            # model output
            outputs = model(inputs)
            predicted = torch.max(outputs.data, 1)[1]

            # convert to list
            predicted = predicted.tolist()
            img = list(img)

            for id, label in zip(img, predicted):
                df = df.append({'id': id, 'label': label}, ignore_index=True)
    return df


# Select GPU when available; checkpoint tensors are mapped onto the same device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Only the test split is needed for inference.
_, _, test_loader = load_data()

# Restore the checkpoint of the best-performing epoch.
best_epoch = 30
best = torch.load('./save/resnet50/train/epoch-{:02d}.pth'.format(best_epoch),
                  map_location=device)
# Use a distinct name so the `resnet50` constructor is not shadowed.
model = resnet50()
model.load_state_dict(best)
model.to(device)  # inputs are moved to `device` in model_test; the model must be too
model.eval()      # disable dropout / batch-norm updates for inference

df = model_test(model, test_loader, device)
df.to_csv('resnet50_result.csv', index=False)
Example #2
0
File: train.py  Project: xpwu95/LDL
def trainval_test(cross_val_index, sigma, lam):
    """Train the joint classification / lesion-counting network on one
    cross-validation split and periodically evaluate it on the test split.

    Args:
        cross_val_index: string identifier of the split; selects the
            NNEW_trainval_<i>.txt / NNEW_test_<i>.txt image lists.
        sigma: bandwidth passed to genLD when converting integer counts
            into label distributions.
        lam: mixing weight between the classification KL terms and the
            counting KL term in the total loss.

    Relies on module-level globals: DATA_PATH, BATCH_SIZE, BATCH_SIZE_TEST,
    NUM_WORKERS, LR, lr_steps, log, and helpers genLD, AverageMeter,
    time_to_str, report_precision_se_sp_yi, report_mae_mse.
    """

    TRAIN_FILE = '/home/ubuntu5/wxp/datasets/acne4/VOCdevkit2007/VOC2007/ImageSets/Main/NNEW_trainval_' + cross_val_index + '.txt'
    TEST_FILE = '/home/ubuntu5/wxp/datasets/acne4/VOCdevkit2007/VOC2007/ImageSets/Main/NNEW_test_' + cross_val_index + '.txt'

    # Dataset-specific channel statistics (not the ImageNet defaults).
    normalize = transforms.Normalize(mean=[0.45815152, 0.361242, 0.29348266],
                                     std=[0.2814769, 0.226306, 0.20132513])

    dset_train = dataset_processing.DatasetProcessing(
        DATA_PATH, TRAIN_FILE, transform=transforms.Compose([
                transforms.Scale((256, 256)),
                transforms.RandomCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                RandomRotate(rotation_range=20),
                normalize,
            ]))

    dset_test = dataset_processing.DatasetProcessing(
        DATA_PATH, TEST_FILE, transform=transforms.Compose([
                transforms.Scale((224, 224)),
                transforms.ToTensor(),
                normalize,
            ]))

    train_loader = DataLoader(dset_train,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=NUM_WORKERS,
                              pin_memory=True)

    test_loader = DataLoader(dset_test,
                             batch_size=BATCH_SIZE_TEST,
                             shuffle=False,
                             num_workers=NUM_WORKERS,
                             pin_memory=True)

    cnn = resnet50().cuda()
    cudnn.benchmark = True

    # Build per-parameter optimizer groups.
    # NOTE(review): both branches below are currently identical, so the split
    # on new_param_names has no effect — presumably the new layers were meant
    # to get a larger LR; confirm the intended multipliers.
    params = []
    new_param_names = ['fc', 'counting']
    for key, value in dict(cnn.named_parameters()).items():
        if value.requires_grad:
            if any(i in key for i in new_param_names):
                params += [{'params': [value], 'lr': LR * 1.0, 'weight_decay': 5e-4}]
            else:
                params += [{'params': [value], 'lr': LR * 1.0, 'weight_decay': 5e-4}]

    optimizer = torch.optim.SGD(params, momentum=0.9)  #

    loss_func = nn.CrossEntropyLoss().cuda()
    kl_loss_1 = nn.KLDivLoss().cuda()
    kl_loss_2 = nn.KLDivLoss().cuda()
    kl_loss_3 = nn.KLDivLoss().cuda()

    def adjust_learning_rate_new(optimizer, decay=0.5):
        """Sets the learning rate to the initial LR decayed by 0.5 every 20 epochs"""
        for param_group in optimizer.param_groups:
            param_group['lr'] = decay * param_group['lr']

    # training and testing
    start = timer()
    test_acc_his = 0.7
    test_mae_his = 8
    test_mse_his = 18
    for epoch in range(lr_steps[-1]):  # alternative: range(EPOCH)

        # Step-wise LR schedule: halve the LR at each milestone epoch.
        if epoch in lr_steps:
            adjust_learning_rate_new(optimizer, 0.5)
        # scheduler.step(epoch)

        losses_cls = AverageMeter()
        losses_cou = AverageMeter()
        losses_cou2cls = AverageMeter()
        losses = AverageMeter()
        # '''
        cnn.train()
        for step, (b_x, b_y, b_l) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loader

            b_x = b_x.cuda()
            b_l = b_l.numpy()

            # generating ld: shift counts to 0-based, then build a label
            # distribution over 65 count bins plus its 4-class marginal.
            b_l = b_l - 1
            ld = genLD(b_l, sigma, 'klloss', 65)
            ld_4 = np.vstack((np.sum(ld[:, :5], 1), np.sum(ld[:, 5:20], 1), np.sum(ld[:, 20:50], 1), np.sum(ld[:, 50:], 1))).transpose()
            ld = torch.from_numpy(ld).cuda().float()
            ld_4 = torch.from_numpy(ld_4).cuda().float()

            # train
            cnn.train()

            cls, cou, cou2cls = cnn(b_x, None)  # cnn output
            # KLDivLoss expects log-probabilities for the prediction side.
            loss_cls = kl_loss_1(torch.log(cls), ld_4) * 4.0
            loss_cou = kl_loss_2(torch.log(cou), ld) * 65.0
            loss_cls_cou = kl_loss_3(torch.log(cou2cls), ld_4) * 4.0
            loss = (loss_cls + loss_cls_cou) * 0.5 * lam + loss_cou * (1.0 - lam)
            optimizer.zero_grad()           # clear gradients for this training step
            loss.backward()                 # backpropagation, compute gradients
            optimizer.step()                # apply gradients

            losses_cls.update(loss_cls.item(), b_x.size(0))
            losses_cou.update(loss_cou.item(), b_x.size(0))
            losses_cou2cls.update(loss_cls_cou.item(), b_x.size(0))
            losses.update(loss.item(), b_x.size(0))
        message = '%s %6.0f | %0.3f | %0.3f | %0.3f | %0.3f | %s\n' % ( \
                "train", epoch,
                losses_cls.avg,
                losses_cou.avg,
                losses_cou2cls.avg,
                losses.avg,
                time_to_str((timer() - start), 'min'))
        # print(message)
        log.write(message)
        # '''
        # Evaluate only after a warm-up period.
        if epoch >= 9:
            with torch.no_grad():
                test_loss = 0
                test_corrects = 0
                y_true = np.array([])
                y_pred = np.array([])
                y_pred_m = np.array([])
                l_true = np.array([])
                l_pred = np.array([])
                cnn.eval()
                for step, (test_x, test_y, test_l) in enumerate(test_loader):   # gives batch data, normalize x when iterate train_loader

                    test_x = test_x.cuda()
                    test_y = test_y.cuda()

                    y_true = np.hstack((y_true, test_y.data.cpu().numpy()))
                    l_true = np.hstack((l_true, test_l.data.cpu().numpy()))

                    cnn.eval()

                    cls, cou, cou2cls = cnn(test_x, None)

                    loss = loss_func(cou2cls, test_y)
                    test_loss += loss.data

                    # Class prediction from the cls head alone and from the
                    # ensemble of cls + count-to-class heads.
                    _, preds_m = torch.max(cls + cou2cls, 1)
                    _, preds = torch.max(cls, 1)
                    # preds = preds.data.cpu().numpy()
                    y_pred = np.hstack((y_pred, preds.data.cpu().numpy()))
                    y_pred_m = np.hstack((y_pred_m, preds_m.data.cpu().numpy()))

                    # Count prediction: shift back to 1-based counts.
                    _, preds_l = torch.max(cou, 1)
                    preds_l = (preds_l + 1).data.cpu().numpy()
                    # preds_l = cou2cou.data.cpu().numpy()
                    l_pred = np.hstack((l_pred, preds_l))

                    batch_corrects = torch.sum((preds == test_y)).data.cpu().numpy()
                    test_corrects += batch_corrects

                test_loss = test_loss.float() / len(test_loader)
                test_acc = test_corrects / len(test_loader.dataset)#3292  #len(test_loader)
                message = '%s %6.1f | %0.3f | %0.3f\n' % ( \
                        "test ", epoch,
                        test_loss.data,
                        test_acc)

                _, _, pre_se_sp_yi_report = report_precision_se_sp_yi(y_pred, y_true)
                _, _, pre_se_sp_yi_report_m = report_precision_se_sp_yi(y_pred_m, y_true)
                _, MAE, MSE, mae_mse_report = report_mae_mse(l_true, l_pred, y_true)

                if True:
                    log.write(str(pre_se_sp_yi_report) + '\n')
                    log.write(str(pre_se_sp_yi_report_m) + '\n')
                    log.write(str(mae_mse_report) + '\n')
 def get_model(x_input, network):
     """Build the eval-mode (is_training=False) graph of ``network`` on ``x_input``.

     Dispatches on the network name and calls the matching builder.
     Raises InvalidNetworkName for an unrecognized name.
     """
     # seresnet_fixed uses a different builder signature, so handle it apart.
     if network == 'seresnet_fixed':
         return get_resnet(x_input,
                           152,
                           type='se_ir',
                           trainable=False,
                           reuse=True)

     # network name -> (builder, extra keyword arguments besides the
     # shared is_training / kernel_initializer ones)
     dispatch = {
         'resnet50': (resnet50, {'reuse': False}),
         'resnet18': (resnet18, {'reuse': False}),
         'resnet34': (resnet34, {'reuse': False}),
         'seresnet50': (se_resnet50, {'reuse': False}),
         'resnet110': (resnet110, {'reuse': False}),
         'seresnet110': (se_resnet110, {'reuse': False}),
         'seresnet152': (se_resnet152, {'reuse': False}),
         'resnet152': (resnet152, {'reuse': False}),
         'densenet121': (densenet121, {'reuse': False}),
         'densenet169': (densenet169, {'reuse': False}),
         'densenet201': (densenet201, {'reuse': False}),
         'densenet161': (densenet161, {'reuse': False}),
         'densenet100bc': (densenet100bc, {'reuse': True}),
         'densenet190bc': (densenet190bc, {'reuse': True}),
         'resnext50': (resnext50, {'reuse': False, 'cardinality': 32}),
         'resnext110': (resnext110, {'reuse': False, 'cardinality': 32}),
         'resnext152': (resnext152, {'reuse': False, 'cardinality': 32}),
         'seresnext50': (se_resnext50, {'reuse': True, 'cardinality': 32}),
         'seresnext110': (se_resnext110, {'reuse': True, 'cardinality': 32}),
         'seresnext152': (se_resnext152, {'reuse': True, 'cardinality': 32}),
     }
     if network not in dispatch:
         raise InvalidNetworkName('Network name is invalid!')
     builder, extra = dispatch[network]
     return builder(x_input,
                    is_training=False,
                    kernel_initializer=None,
                    **extra)
    def __init__(self,
                 nbScale,
                 nbIter,
                 tolerance,
                 transform,
                 minSize,
                 segId=1,
                 segFg=True,
                 imageNet=False,
                 scaleR=2):
        """Set up a truncated, MoCo-pretrained ResNet-50 feature extractor
        plus the geometric-transform and multi-scale hyper-parameters.

        nbScale   -- number of scales in the image pyramid (1 = single scale)
        nbIter    -- number of optimization iterations
        tolerance -- convergence tolerance for the optimization
        transform -- 'Affine' for affine, anything else for homography
        minSize   -- minimum image size fed to the network
        scaleR    -- ratio bounding the scale pyramid [scaleR, 1/scaleR]
        """
        # Optimization hyper-parameters.
        self.nbIter = nbIter
        self.tolerance = tolerance

        # ResNet-50 is truncated after 'layer3' and used as a frozen
        # feature extractor.
        feature_layer_names = [
            'conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3'
        ]

        backbone = resnet50()
        weight_path = 'model/pretrained/resnet50_moco.pth'
        checkpoint = torch.load(weight_path)
        # The checkpoint was saved from a wrapped model: strip the
        # 'module.' prefix so the keys match a plain resnet50.
        state_dict = {
            name.replace("module.", ""): tensor
            for name, tensor in checkpoint['model'].items()
        }
        print('Loading pretrained model from {}'.format(weight_path))
        backbone.load_state_dict(state_dict)

        keep = feature_layer_names.index('layer3') + 1
        submodules = [getattr(backbone, name) for name in feature_layer_names]
        self.net = torch.nn.Sequential(*submodules[:keep])

        self.net.cuda()
        self.net.eval()

        # Preprocessing: ImageNet normalization statistics.
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        self.toTensor = transforms.ToTensor()
        self.preproc = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        # Geometric model: an affine needs 3 point pairs, a homography 4.
        if transform == 'Affine':
            self.Transform = outil.Affine
            self.nbPoint = 3
        else:
            self.Transform = outil.Homography
            self.nbPoint = 4

        self.strideNet = 16
        self.minSize = minSize

        # Scale pyramid, geometric and symmetric around 1: from scaleR down
        # to 1/scaleR, dropping the duplicated middle value.
        if nbScale == 1:
            self.scaleList = [1]
        else:
            self.scaleList = np.linspace(
                scaleR, 1, nbScale // 2 + 1).tolist() + np.linspace(
                    1, 1 / scaleR, nbScale // 2 + 1).tolist()[1:]
        print(self.scaleList)
Example #5
0
def train(args):
    """Train the selected network on CIFAR-100 using a TF1 static graph.

    Expects on ``args``: batch_size, epoch, network (builder name), opt
    ('adam' / 'momentum' / 'nesterov'), train_path and test_path (pickled
    CIFAR-style batches).  Builds a training graph and a weight-sharing
    test graph, then runs the session loop, checkpointing whenever test
    accuracy improves.
    """
    batch_size = args.batch_size
    epoch = args.epoch
    network = args.network
    opt = args.opt
    # Pickled CIFAR batches: b'data' is (N, 3072) uint8, labels under
    # b'fine_labels' (100 classes).
    train = unpickle(args.train_path)
    test = unpickle(args.test_path)
    train_data = train[b'data']
    test_data = test[b'data']

    # Reshape flat rows to NCHW, then transpose to NHWC for TF.
    x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    x_train = x_train.transpose(0, 2, 3, 1)
    y_train = train[b'fine_labels']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test[b'fine_labels']

    # Per-dataset normalization.
    x_train = norm_images(x_train)
    x_test = norm_images(x_test)

    print('-------------------------------')
    print('--train/test len: ', len(train_data), len(test_data))
    print('--x_train norm: ', compute_mean_var(x_train))
    print('--x_test norm: ', compute_mean_var(x_test))
    print('--batch_size: ', batch_size)
    print('--epoch: ', epoch)
    print('--network: ', network)
    print('--opt: ', opt)
    print('-------------------------------')

    # Materialize both splits as tfrecords once (skipped if already present).
    if not os.path.exists('./trans/tran.tfrecords'):
        generate_tfrecord(x_train, y_train, './trans/', 'tran.tfrecords')
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')

    # Training input pipeline.
    dataset = tf.data.TFRecordDataset('./trans/tran.tfrecords')
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    # Shared placeholders: both the train and the test graphs below are fed
    # through x_input / y_input.
    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [
        None,
    ])
    y_input_one_hot = tf.one_hot(y_input, 100)
    lr = tf.placeholder(tf.float32, [])

    # Build the training graph (is_training=True, fresh variables).
    if network == 'resnet50':
        prob = resnet50(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet34':
        prob = resnet34(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.contrib.layers.
                        variance_scaling_initializer())
    elif network == 'resnet18':
        prob = resnet18(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.contrib.layers.
                        variance_scaling_initializer())
    elif network == 'seresnet50':
        prob = se_resnet50(x_input,
                           is_training=True,
                           reuse=False,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet110':
        prob = resnet110(x_input,
                         is_training=True,
                         reuse=False,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet110':
        prob = se_resnet110(x_input,
                            is_training=True,
                            reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet152':
        prob = se_resnet152(x_input,
                            is_training=True,
                            reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet152':
        # NOTE(review): unlike the other branches this one passes no
        # `reuse` argument — confirm the builder's default matches.
        prob = resnet152(x_input,
                         is_training=True,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet_fixed':
        prob = get_resnet(x_input,
                          152,
                          trainable=True,
                          w_init=tf.orthogonal_initializer())
    elif network == 'densenet121':
        prob = densenet121(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet169':
        prob = densenet169(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet201':
        prob = densenet201(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet161':
        prob = densenet161(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet100bc':
        prob = densenet100bc(x_input,
                             reuse=False,
                             is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet190bc':
        prob = densenet190bc(x_input,
                             reuse=False,
                             is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext50':
        prob = resnext50(x_input,
                         reuse=False,
                         is_training=True,
                         cardinality=32,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext110':
        prob = resnext110(x_input,
                          reuse=False,
                          is_training=True,
                          cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext152':
        prob = resnext152(x_input,
                          reuse=False,
                          is_training=True,
                          cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext50':
        prob = se_resnext50(x_input,
                            reuse=False,
                            is_training=True,
                            cardinality=32,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext110':
        prob = se_resnext110(x_input,
                             reuse=False,
                             is_training=True,
                             cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext152':
        prob = se_resnext152(x_input,
                             reuse=False,
                             is_training=True,
                             cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())

    # Cross-entropy on logits plus L2 regularization on conv weights.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=prob,
                                                labels=y_input_one_hot))

    conv_var = [var for var in tf.trainable_variables() if 'conv' in var.name]
    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in conv_var])
    loss = l2_loss * 5e-4 + loss

    if opt == 'adam':
        opt = tf.train.AdamOptimizer(lr)
    elif opt == 'momentum':
        opt = tf.train.MomentumOptimizer(lr, 0.9)
    elif opt == 'nesterov':
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)

    # Run UPDATE_OPS (e.g. batch-norm moving averages) with every train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(loss)

    logit_softmax = tf.nn.softmax(prob)
    acc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(logit_softmax, 1), y_input), tf.float32))

    #-------------------------------Test-----------------------------------------
    # NOTE(review): this checks tran.tfrecords but generates test.tfrecords —
    # presumably it should test for './trans/test.tfrecords' (as test() does).
    if not os.path.exists('./trans/tran.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    # Evaluation graph: same variables (reuse=True), is_training=False.
    if network == 'resnet50':
        prob_test = resnet50(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input,
                              is_training=False,
                              reuse=True,
                              kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input,
                                 is_training=False,
                                 reuse=True,
                                 kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input,
                                 is_training=False,
                                 reuse=True,
                                 kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input,
                              is_training=False,
                              reuse=True,
                              kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input,
                               152,
                               type='se_ir',
                               trainable=False,
                               reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input,
                                  reuse=True,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input,
                                  reuse=True,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input,
                              is_training=False,
                              reuse=True,
                              cardinality=32,
                              kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input,
                               is_training=False,
                               reuse=True,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input,
                               is_training=False,
                               reuse=True,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input,
                                 reuse=True,
                                 is_training=False,
                                 cardinality=32,
                                 kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input,
                                  reuse=True,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input,
                                  reuse=True,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)

    logit_softmax_test = tf.nn.softmax(prob_test)
    # reduce_sum (not mean): per-batch correct counts are accumulated and
    # divided by len(y_test) at the end of each evaluation pass.
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input),
                tf.float32))
    #----------------------------------------------------------------------------
    saver = tf.train.Saver(max_to_keep=1, var_list=tf.global_variables())
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    now_lr = 0.001  # Warm Up
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        counter = 0
        max_test_acc = -1
        for i in range(epoch):
            sess.run(iterator.initializer)
            # Drain the training iterator; OutOfRangeError marks epoch end.
            while True:
                try:
                    batch_train, label_train = sess.run(next_element)
                    _, loss_val, acc_val, lr_val = sess.run(
                        [train_op, loss, acc, lr],
                        feed_dict={
                            x_input: batch_train,
                            y_input: label_train,
                            lr: now_lr
                        })

                    counter += 1

                    if counter % 100 == 0:
                        print('counter: ', counter, 'loss_val', loss_val,
                              'acc: ', acc_val)
                    # Every 1000 steps: full pass over the test set; keep a
                    # checkpoint whenever accuracy improves.
                    if counter % 1000 == 0:
                        print('start test ')
                        sess.run(iterator_test.initializer)
                        avg_acc = []
                        while True:
                            try:
                                batch_test, label_test = sess.run(
                                    next_element_test)
                                acc_test_val = sess.run(acc_test,
                                                        feed_dict={
                                                            x_input:
                                                            batch_test,
                                                            y_input: label_test
                                                        })
                                avg_acc.append(acc_test_val)
                            except tf.errors.OutOfRangeError:
                                print('end test ',
                                      np.sum(avg_acc) / len(y_test))
                                now_test_acc = np.sum(avg_acc) / len(y_test)
                                if now_test_acc > max_test_acc:
                                    print('***** Max test changed: ',
                                          now_test_acc)
                                    max_test_acc = now_test_acc
                                    filename = 'params/distinct/' + network + '_{}.ckpt'.format(
                                        counter)
                                    saver.save(sess, filename)
                                break
                except tf.errors.OutOfRangeError:
                    print('end epoch %d/%d , lr: %f' % (i, epoch, lr_val))
                    now_lr = lr_schedule(i, args.epoch)
                    break
Example #6
0
def test(args):
    """Evaluate a saved checkpoint on the CIFAR-100 test split.

    Restores the architecture named by ``args.network`` from the checkpoint
    ``args.ckpt`` and prints the top-1 accuracy over the full test set.

    Args:
        args: parsed CLI namespace. Reads ``test_path`` (pickled CIFAR-100
            test batch), ``network`` (architecture name) and ``ckpt``
            (checkpoint path to restore).

    Raises:
        ValueError: if ``args.network`` is not a recognized architecture.
    """
    # Load the pickled CIFAR-100 test batch and reshape NCHW -> NHWC.
    # (Local renamed from `test` to avoid shadowing this function's name.)
    raw = unpickle(args.test_path)
    test_data = raw[b'data']
    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = raw[b'fine_labels']

    # Normalize images (see norm_images for the statistics used).
    x_test = norm_images(x_test)

    network = args.network
    ckpt = args.ckpt

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None])

    # ---------------- input pipeline (one pass over the test set) ----------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    # NOTE(review): shuffling is unnecessary for evaluation (the summed
    # accuracy is order-independent); kept to preserve the original pipeline.
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()

    # ---------------- build the inference graph ----------------------------
    # Dispatch tables replace the original ~20-branch if/elif chain.
    # Plain builders share the (is_training, reuse, kernel_initializer)
    # signature; ResNeXt variants additionally take `cardinality`.
    plain_builders = {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet110': resnet110,
        'resnet152': resnet152,
        'seresnet50': se_resnet50,
        'seresnet110': se_resnet110,
        'seresnet152': se_resnet152,
        'densenet121': densenet121,
        'densenet161': densenet161,
        'densenet169': densenet169,
        'densenet201': densenet201,
        'densenet100bc': densenet100bc,
        'densenet190bc': densenet190bc,
    }
    resnext_builders = {
        'resnext50': resnext50,
        'resnext110': resnext110,
        'resnext152': resnext152,
        'seresnext50': se_resnext50,
        'seresnext110': se_resnext110,
        'seresnext152': se_resnext152,
    }

    if network == 'seresnet_fixed':
        # Distinct helper with its own signature; reuse=True matches the
        # original call site.
        prob_test = get_resnet(x_input,
                               152,
                               type='se_ir',
                               trainable=False,
                               reuse=True)
    elif network in resnext_builders:
        prob_test = resnext_builders[network](x_input,
                                              is_training=False,
                                              reuse=False,
                                              cardinality=32,
                                              kernel_initializer=None)
    elif network in plain_builders:
        prob_test = plain_builders[network](x_input,
                                            is_training=False,
                                            reuse=False,
                                            kernel_initializer=None)
    else:
        # The original chain had no else-branch: an unknown name left
        # `prob_test` undefined and crashed later with a NameError.
        # Fail fast with a clear message instead.
        raise ValueError('unknown network: {!r}'.format(network))

    # Per-batch COUNT of correct predictions (reduce_sum, not reduce_mean),
    # so the final accuracy is total_correct / len(y_test).
    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input),
                tf.float32))

    # Restore trainable variables plus batch-norm moving statistics, which
    # are global (not trainable) variables and would otherwise be skipped.
    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars

    saver = tf.train.Saver(var_list=var_list)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt)
        sess.run(iterator_test.initializer)
        correct_counts = []
        while True:
            try:
                batch_test, label_test = sess.run(next_element_test)
                acc_test_val = sess.run(acc_test,
                                        feed_dict={
                                            x_input: batch_test,
                                            y_input: label_test
                                        })
                correct_counts.append(acc_test_val)
            except tf.errors.OutOfRangeError:
                # Dataset exhausted: report overall top-1 accuracy.
                print('end test ', np.sum(correct_counts) / len(y_test))
                break
# --- Example #7 ---
# NOTE(review): purpose of this startup delay is not evident from this
# excerpt — confirm whether it is still needed.
time.sleep(1.5)

# config
# NOTE(review): epoch_total / n_dis / display_step / save_step appear to
# configure a training loop past this excerpt — verify downstream usage.
batch_size = 40       # samples per batch; passed to DataProvider below
epoch_total = 1000
n_dis = 1
is_cuda = True        # when True, models and data provider run on GPU
display_step = 10
save_step = 100
# net_pretrain = '/home/elijha/Documents/PycharmProjects/ReIDGAN/workdir/save/save-3400'
# Truthy flag (historically a checkpoint path, per the comment above):
# selects the pretrained-weights branch below.
net_pretrain = True

# build graph
if net_pretrain:
    # Restore both networks from hard-coded checkpoint paths.
    feature_extractor = resnet50(
        pretrained_path="/home/nhli/PycharmProj/ReIDGAN_/params/save-fea-2000"
    )  # pretrained
    dis = Discriminator()
    dis.load_state_dict(
        torch.load("/home/nhli/PycharmProj/ReIDGAN_/params/save-dis-2000"))
else:
    # Fresh networks: no pretrained weights loaded.
    feature_extractor = resnet50(pretrained_path=None)
    dis = Discriminator()

if is_cuda:
    # Move both models to GPU in place.
    feature_extractor.cuda()
    dis.cuda()

# input pipeline
data_iter = DataProvider(batch_size, is_cuda=is_cuda)