Example #1
 def __init__(self, out_size):
     super(DenseNet121, self).__init__()
     self.inplanes = 1024
     self.densenet121 = densenet.densenet121(
         pretrained=True, small=args.small)  # small = 1, pool = 0
     num_ftrs = self.densenet121.classifier.in_features  # 1024
     self.classifier_font = nn.Sequential(
         # an fc layer could do the classification here:
         # nn.Linear(num_ftrs, out_size)
         # or a 1x1 convolution can do it:
         nn.Conv2d(num_ftrs, out_size, kernel_size=1,
                   bias=False))  # attach a conv directly and classify into 1823 classes
     self.train_params = []
     self.unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
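A side note on the two classifier options commented in Example #1: a 1×1 convolution applies the same linear map as nn.Linear at every spatial position, so it keeps the H×W layout instead of collapsing it. A minimal sketch of the equivalence (batch and feature-map sizes are assumed for illustration):

import torch
import torch.nn as nn

num_ftrs, out_size = 1024, 1823             # values from the snippet above
features = torch.randn(2, num_ftrs, 7, 7)   # dummy DenseNet feature map

conv_head = nn.Conv2d(num_ftrs, out_size, kernel_size=1, bias=False)
per_pixel_logits = conv_head(features)      # (2, 1823, 7, 7): one prediction per location

linear_head = nn.Linear(num_ftrs, out_size, bias=False)
pooled_logits = linear_head(features.mean(dim=(2, 3)))  # (2, 1823): one prediction per image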
Example #2
def create_model_optimizer_scheduler(args, dataset_class, optimizer='adam', scheduler='steplr',
                                     load_optimizer_scheduler=False):
    if args.arch == 'wideresnet':
        model = WideResNet(depth=args.layers,
                           num_classes=dataset_class.num_classes,
                           widen_factor=args.widen_factor,
                           dropout_rate=args.drop_rate)
    elif args.arch == 'densenet':
        model = densenet121(num_classes=dataset_class.num_classes)
    elif args.arch == 'lenet':
        model = LeNet(num_channels=3, num_classes=dataset_class.num_classes,
                      droprate=args.drop_rate, input_size=dataset_class.input_size)
    elif args.arch == 'resnet':
        model = resnet18(num_classes=dataset_class.num_classes, input_size=dataset_class.input_size,
                         drop_rate=args.drop_rate)
    else:
        raise NotImplementedError

    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    model = model.cuda()

    if optimizer == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum,
                                    nesterov=args.nesterov, weight_decay=args.weight_decay)

    if scheduler == 'steplr':
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.2)
    else:
        args.iteration = args.fixmatch_k_img // args.batch_size
        args.total_steps = args.fixmatch_epochs * args.iteration
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, args.fixmatch_warmup * args.iteration, args.total_steps)

    if args.resume:
        if load_optimizer_scheduler:
            model, optimizer, scheduler = resume_model(args, model, optimizer, scheduler)
        else:
            model, _, _ = resume_model(args, model)

    return model, optimizer, scheduler
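A hypothetical call sketch for Example #2. The argument names are read off the function body, and DatasetInfo is an invented stand-in for dataset_class; model.cuda() inside the function requires a CUDA device:

from types import SimpleNamespace

args = SimpleNamespace(arch='resnet', drop_rate=0.3, lr=0.1, momentum=0.9,
                       nesterov=True, weight_decay=5e-4, resume=False)

class DatasetInfo:  # stand-in for dataset_class
    num_classes = 10
    input_size = 32

model, optimizer, scheduler = create_model_optimizer_scheduler(
    args, DatasetInfo, optimizer='sgd', scheduler='steplr')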
Example #3
 def get_model(x_input, network):
     if network == 'resnet50':
         return resnet50(x_input,
                         is_training=False,
                         reuse=False,
                         kernel_initializer=None)
     elif network == 'resnet18':
         return resnet18(x_input,
                         is_training=False,
                         reuse=False,
                         kernel_initializer=None)
     elif network == 'resnet34':
         return resnet34(x_input,
                         is_training=False,
                         reuse=False,
                         kernel_initializer=None)
     elif network == 'seresnet50':
         return se_resnet50(x_input,
                            is_training=False,
                            reuse=False,
                            kernel_initializer=None)
     elif network == 'resnet110':
         return resnet110(x_input,
                          is_training=False,
                          reuse=False,
                          kernel_initializer=None)
     elif network == 'seresnet110':
         return se_resnet110(x_input,
                             is_training=False,
                             reuse=False,
                             kernel_initializer=None)
     elif network == 'seresnet152':
         return se_resnet152(x_input,
                             is_training=False,
                             reuse=False,
                             kernel_initializer=None)
     elif network == 'resnet152':
         return resnet152(x_input,
                          is_training=False,
                          reuse=False,
                          kernel_initializer=None)
     elif network == 'seresnet_fixed':
         return get_resnet(x_input,
                           152,
                           type='se_ir',
                           trainable=False,
                           reuse=True)
     elif network == 'densenet121':
         return densenet121(x_input,
                            is_training=False,
                            reuse=False,
                            kernel_initializer=None)
     elif network == 'densenet169':
         return densenet169(x_input,
                            is_training=False,
                            reuse=False,
                            kernel_initializer=None)
     elif network == 'densenet201':
         return densenet201(x_input,
                            is_training=False,
                            reuse=False,
                            kernel_initializer=None)
     elif network == 'densenet161':
         return densenet161(x_input,
                            is_training=False,
                            reuse=False,
                            kernel_initializer=None)
     elif network == 'densenet100bc':
         return densenet100bc(x_input,
                              reuse=True,
                              is_training=False,
                              kernel_initializer=None)
     elif network == 'densenet190bc':
         return densenet190bc(x_input,
                              reuse=True,
                              is_training=False,
                              kernel_initializer=None)
     elif network == 'resnext50':
         return resnext50(x_input,
                          is_training=False,
                          reuse=False,
                          cardinality=32,
                          kernel_initializer=None)
     elif network == 'resnext110':
         return resnext110(x_input,
                           is_training=False,
                           reuse=False,
                           cardinality=32,
                           kernel_initializer=None)
     elif network == 'resnext152':
         return resnext152(x_input,
                           is_training=False,
                           reuse=False,
                           cardinality=32,
                           kernel_initializer=None)
     elif network == 'seresnext50':
         return se_resnext50(x_input,
                             reuse=True,
                             is_training=False,
                             cardinality=32,
                             kernel_initializer=None)
     elif network == 'seresnext110':
         return se_resnext110(x_input,
                              reuse=True,
                              is_training=False,
                              cardinality=32,
                              kernel_initializer=None)
     elif network == 'seresnext152':
         return se_resnext152(x_input,
                              reuse=True,
                              is_training=False,
                              cardinality=32,
                              kernel_initializer=None)
     raise InvalidNetworkName('Invalid network name: {}'.format(network))
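The if/elif ladder in Example #3 can be collapsed into a lookup table without changing behavior; a sketch, assuming the constructors keep the keyword interface shown above:

from functools import partial

_EVAL_MODELS = {
    'resnet50':  partial(resnet50,  is_training=False, reuse=False, kernel_initializer=None),
    'resnext50': partial(resnext50, is_training=False, reuse=False, cardinality=32,
                         kernel_initializer=None),
    # ... remaining networks follow the same pattern
}

def get_model_v2(x_input, network):
    try:
        return _EVAL_MODELS[network](x_input)
    except KeyError:
        raise InvalidNetworkName('Invalid network name: {}'.format(network))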
Example #4
def train(args):
    batch_size = args.batch_size
    epoch = args.epoch
    network = args.network
    opt = args.opt
    train = unpickle(args.train_path)
    test = unpickle(args.test_path)
    train_data = train[b'data']
    test_data = test[b'data']

    x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    x_train = x_train.transpose(0, 2, 3, 1)
    y_train = train[b'fine_labels']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test[b'fine_labels']

    x_train = norm_images(x_train)
    x_test = norm_images(x_test)

    print('-------------------------------')
    print('--train/test len: ', len(train_data), len(test_data))
    print('--x_train norm: ', compute_mean_var(x_train))
    print('--x_test norm: ', compute_mean_var(x_test))
    print('--batch_size: ', batch_size)
    print('--epoch: ', epoch)
    print('--network: ', network)
    print('--opt: ', opt)
    print('-------------------------------')

    if not os.path.exists('./trans/tran.tfrecords'):
        generate_tfrecord(x_train, y_train, './trans/', 'tran.tfrecords')
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')

    dataset = tf.data.TFRecordDataset('./trans/tran.tfrecords')
    dataset = dataset.map(parse_function)
    dataset = dataset.shuffle(buffer_size=50000)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None])
    y_input_one_hot = tf.one_hot(y_input, 100)
    lr = tf.placeholder(tf.float32, [])

    if network == 'resnet50':
        prob = resnet50(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet34':
        prob = resnet34(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.contrib.layers.
                        variance_scaling_initializer())
    elif network == 'resnet18':
        prob = resnet18(x_input,
                        is_training=True,
                        reuse=False,
                        kernel_initializer=tf.contrib.layers.
                        variance_scaling_initializer())
    elif network == 'seresnet50':
        prob = se_resnet50(x_input,
                           is_training=True,
                           reuse=False,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet110':
        prob = resnet110(x_input,
                         is_training=True,
                         reuse=False,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet110':
        prob = se_resnet110(x_input,
                            is_training=True,
                            reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet152':
        prob = se_resnet152(x_input,
                            is_training=True,
                            reuse=False,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnet152':
        prob = resnet152(x_input,
                         is_training=True,
                         reuse=False,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnet_fixed':
        prob = get_resnet(x_input,
                          152,
                          trainable=True,
                          w_init=tf.orthogonal_initializer())
    elif network == 'densenet121':
        prob = densenet121(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet169':
        prob = densenet169(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet201':
        prob = densenet201(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet161':
        prob = densenet161(x_input,
                           reuse=False,
                           is_training=True,
                           kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet100bc':
        prob = densenet100bc(x_input,
                             reuse=False,
                             is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'densenet190bc':
        prob = densenet190bc(x_input,
                             reuse=False,
                             is_training=True,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext50':
        prob = resnext50(x_input,
                         reuse=False,
                         is_training=True,
                         cardinality=32,
                         kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext110':
        prob = resnext110(x_input,
                          reuse=False,
                          is_training=True,
                          cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'resnext152':
        prob = resnext152(x_input,
                          reuse=False,
                          is_training=True,
                          cardinality=32,
                          kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext50':
        prob = se_resnext50(x_input,
                            reuse=False,
                            is_training=True,
                            cardinality=32,
                            kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext110':
        prob = se_resnext110(x_input,
                             reuse=False,
                             is_training=True,
                             cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())
    elif network == 'seresnext152':
        prob = se_resnext152(x_input,
                             reuse=False,
                             is_training=True,
                             cardinality=32,
                             kernel_initializer=tf.orthogonal_initializer())

    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=prob,
                                                labels=y_input_one_hot))

    conv_var = [var for var in tf.trainable_variables() if 'conv' in var.name]
    l2_loss = tf.add_n([tf.nn.l2_loss(var) for var in conv_var])
    loss = l2_loss * 5e-4 + loss

    if opt == 'adam':
        opt = tf.train.AdamOptimizer(lr)
    elif opt == 'momentum':
        opt = tf.train.MomentumOptimizer(lr, 0.9)
    elif opt == 'nesterov':
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(loss)

    logit_softmax = tf.nn.softmax(prob)
    acc = tf.reduce_mean(
        tf.cast(tf.equal(tf.argmax(logit_softmax, 1), y_input), tf.float32))

    #-------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input,
                             is_training=False,
                             reuse=True,
                             kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input,
                              is_training=False,
                              reuse=True,
                              kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input,
                                 is_training=False,
                                 reuse=True,
                                 kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input,
                                 is_training=False,
                                 reuse=True,
                                 kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input,
                              is_training=False,
                              reuse=True,
                              kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input,
                               152,
                               type='se_ir',
                               trainable=False,
                               reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input,
                                is_training=False,
                                reuse=True,
                                kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input,
                                  reuse=True,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input,
                                  reuse=True,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input,
                              is_training=False,
                              reuse=True,
                              cardinality=32,
                              kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input,
                               is_training=False,
                               reuse=True,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input,
                               is_training=False,
                               reuse=True,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input,
                                 reuse=True,
                                 is_training=False,
                                 cardinality=32,
                                 kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input,
                                  reuse=True,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input,
                                  reuse=True,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)

    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input),
                tf.float32))
    #----------------------------------------------------------------------------
    saver = tf.train.Saver(max_to_keep=1, var_list=tf.global_variables())
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    now_lr = 0.001  # Warm Up
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())

        counter = 0
        max_test_acc = -1
        for i in range(epoch):
            sess.run(iterator.initializer)
            while True:
                try:
                    batch_train, label_train = sess.run(next_element)
                    _, loss_val, acc_val, lr_val = sess.run(
                        [train_op, loss, acc, lr],
                        feed_dict={
                            x_input: batch_train,
                            y_input: label_train,
                            lr: now_lr
                        })

                    counter += 1

                    if counter % 100 == 0:
                        print('counter: ', counter, 'loss_val', loss_val,
                              'acc: ', acc_val)
                    if counter % 1000 == 0:
                        print('start test ')
                        sess.run(iterator_test.initializer)
                        avg_acc = []
                        while True:
                            try:
                                batch_test, label_test = sess.run(
                                    next_element_test)
                                acc_test_val = sess.run(acc_test,
                                                        feed_dict={
                                                            x_input:
                                                            batch_test,
                                                            y_input: label_test
                                                        })
                                avg_acc.append(acc_test_val)
                            except tf.errors.OutOfRangeError:
                                print('end test ',
                                      np.sum(avg_acc) / len(y_test))
                                now_test_acc = np.sum(avg_acc) / len(y_test)
                                if now_test_acc > max_test_acc:
                                    print('***** Max test changed: ',
                                          now_test_acc)
                                    max_test_acc = now_test_acc
                                    filename = 'params/distinct/' + network + '_{}.ckpt'.format(
                                        counter)
                                    saver.save(sess, filename)
                                break
                except tf.errors.OutOfRangeError:
                    print('end epoch %d/%d, lr: %f' % (i + 1, epoch, lr_val))
                    now_lr = lr_schedule(i, args.epoch)
                    break
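train() in Example #4 ends each epoch by calling lr_schedule(i, args.epoch), which the example does not show. A plausible step-decay implementation (the breakpoints and base rate are assumptions, not from the source):

def lr_schedule(epoch, total_epochs, base_lr=0.1):
    # Drop the learning rate 10x at 50% and again at 75% of training.
    if epoch < total_epochs * 0.5:
        return base_lr
    if epoch < total_epochs * 0.75:
        return base_lr * 0.1
    return base_lr * 0.01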
Example #5
def test(args):
    # train = unpickle('/data/ChuyuanXiong/up/cifar-100-python/train')
    # train_data = train[b'data']
    # x_train = train_data.reshape(train_data.shape[0], 3, 32, 32)
    # x_train = x_train.transpose(0, 2, 3, 1)

    test = unpickle(args.test_path)
    test_data = test[b'data']

    x_test = test_data.reshape(test_data.shape[0], 3, 32, 32)
    x_test = x_test.transpose(0, 2, 3, 1)
    y_test = test[b'fine_labels']

    x_test = norm_images(x_test)
    # x_test = norm_images_using_mean_var(x_test, *compute_mean_var(x_train))

    network = args.network
    ckpt = args.ckpt

    x_input = tf.placeholder(tf.float32, [None, 32, 32, 3])
    y_input = tf.placeholder(tf.int64, [None])
    #-------------------------------Test-----------------------------------------
    if not os.path.exists('./trans/test.tfrecords'):
        generate_tfrecord(x_test, y_test, './trans/', 'test.tfrecords')
    dataset_test = tf.data.TFRecordDataset('./trans/test.tfrecords')
    dataset_test = dataset_test.map(parse_test)
    dataset_test = dataset_test.shuffle(buffer_size=10000)
    dataset_test = dataset_test.batch(128)
    iterator_test = dataset_test.make_initializable_iterator()
    next_element_test = iterator_test.get_next()
    if network == 'resnet50':
        prob_test = resnet50(x_input,
                             is_training=False,
                             reuse=False,
                             kernel_initializer=None)
    elif network == 'resnet18':
        prob_test = resnet18(x_input,
                             is_training=False,
                             reuse=False,
                             kernel_initializer=None)
    elif network == 'resnet34':
        prob_test = resnet34(x_input,
                             is_training=False,
                             reuse=False,
                             kernel_initializer=None)
    elif network == 'seresnet50':
        prob_test = se_resnet50(x_input,
                                is_training=False,
                                reuse=False,
                                kernel_initializer=None)
    elif network == 'resnet110':
        prob_test = resnet110(x_input,
                              is_training=False,
                              reuse=False,
                              kernel_initializer=None)
    elif network == 'seresnet110':
        prob_test = se_resnet110(x_input,
                                 is_training=False,
                                 reuse=False,
                                 kernel_initializer=None)
    elif network == 'seresnet152':
        prob_test = se_resnet152(x_input,
                                 is_training=False,
                                 reuse=False,
                                 kernel_initializer=None)
    elif network == 'resnet152':
        prob_test = resnet152(x_input,
                              is_training=False,
                              reuse=False,
                              kernel_initializer=None)
    elif network == 'seresnet_fixed':
        prob_test = get_resnet(x_input,
                               152,
                               type='se_ir',
                               trainable=False,
                               reuse=True)
    elif network == 'densenet121':
        prob_test = densenet121(x_input,
                                is_training=False,
                                reuse=False,
                                kernel_initializer=None)
    elif network == 'densenet169':
        prob_test = densenet169(x_input,
                                is_training=False,
                                reuse=False,
                                kernel_initializer=None)
    elif network == 'densenet201':
        prob_test = densenet201(x_input,
                                is_training=False,
                                reuse=False,
                                kernel_initializer=None)
    elif network == 'densenet161':
        prob_test = densenet161(x_input,
                                is_training=False,
                                reuse=False,
                                kernel_initializer=None)
    elif network == 'densenet100bc':
        prob_test = densenet100bc(x_input,
                                  reuse=False,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'densenet190bc':
        prob_test = densenet190bc(x_input,
                                  reuse=False,
                                  is_training=False,
                                  kernel_initializer=None)
    elif network == 'resnext50':
        prob_test = resnext50(x_input,
                              is_training=False,
                              reuse=False,
                              cardinality=32,
                              kernel_initializer=None)
    elif network == 'resnext110':
        prob_test = resnext110(x_input,
                               is_training=False,
                               reuse=False,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'resnext152':
        prob_test = resnext152(x_input,
                               is_training=False,
                               reuse=False,
                               cardinality=32,
                               kernel_initializer=None)
    elif network == 'seresnext50':
        prob_test = se_resnext50(x_input,
                                 reuse=False,
                                 is_training=False,
                                 cardinality=32,
                                 kernel_initializer=None)
    elif network == 'seresnext110':
        prob_test = se_resnext110(x_input,
                                  reuse=False,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)
    elif network == 'seresnext152':
        prob_test = se_resnext152(x_input,
                                  reuse=False,
                                  is_training=False,
                                  cardinality=32,
                                  kernel_initializer=None)

    # prob_test = tf.layers.dense(prob_test, 100, reuse=True, name='before_softmax')
    logit_softmax_test = tf.nn.softmax(prob_test)
    acc_test = tf.reduce_sum(
        tf.cast(tf.equal(tf.argmax(logit_softmax_test, 1), y_input),
                tf.float32))

    var_list = tf.trainable_variables()
    g_list = tf.global_variables()
    bn_moving_vars = [g for g in g_list if 'moving_mean' in g.name]
    bn_moving_vars += [g for g in g_list if 'moving_variance' in g.name]
    var_list += bn_moving_vars

    saver = tf.train.Saver(var_list=var_list)
    config = tf.ConfigProto()
    config.allow_soft_placement = True
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt)
        sess.run(iterator_test.initializer)
        avg_acc = []
        while True:
            try:
                batch_test, label_test = sess.run(next_element_test)
                acc_test_val = sess.run(acc_test,
                                        feed_dict={
                                            x_input: batch_test,
                                            y_input: label_test
                                        })
                avg_acc.append(acc_test_val)
            except tf.errors.OutOfRangeError:
                print('end test ', np.sum(avg_acc) / len(y_test))
                break
Example #6
    parser.add_argument('--weights', type=str, default='./checkpoints/densenet/85-best.pth', help='the weights file you want to test')  # modified here
    args = parser.parse_args()
    config_path = os.path.join(args.path, 'config.yml')
    
    # load config file
    config = Config(config_path)
   
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(e) for e in config.GPU)
    if torch.cuda.is_available():
        config.DEVICE = torch.device("cuda")
        print('\nGPU IS AVAILABLE')
        torch.backends.cudnn.benchmark = True
    else:
        config.DEVICE = torch.device("cpu")

    net = densenet121().to(config.DEVICE)  # modified

    test_set = ImageFolder(config.TEST_PATH, transform=test_tf)
    test_data = torch.utils.data.DataLoader(test_set, batch_size=config.BATCH_SIZE, shuffle=False)
    
    pth_path = args.weights
    net.load_state_dict(torch.load(pth_path, map_location=config.DEVICE))
    # print(net)
    net.eval()

    correct_1 = 0.0
    correct_5 = 0.0
    total = 0

    for n_iter, (image, label) in enumerate(test_data):
        print("iteration: {}\ttotal {} iterations".format(n_iter + 1, len(test_data)))
Example #7
def main():
    # Load the parameters from json file
    args = parser.parse_args()
    json_path = os.path.join(args.model_dir, 'params.json')
    assert os.path.isfile(
        json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    # Set the random seed for reproducible experiments
    random.seed(230)
    torch.manual_seed(230)
    np.random.seed(230)
    torch.cuda.manual_seed(230)
    warnings.filterwarnings("ignore")

    # Set the logger
    utils.set_logger(os.path.join(args.model_dir, 'train.log'))

    # Create the input data pipeline
    logging.info("Loading the datasets...")

    # fetch dataloaders, considering full-set vs. sub-set scenarios
    if params.subset_percent < 1.0:
        train_dl = data_loader.fetch_subset_dataloader('train', params)
    else:
        train_dl = data_loader.fetch_dataloader('train', params)

    dev_dl = data_loader.fetch_dataloader('dev', params)

    logging.info("- done.")
    """
    Load student and teacher model
    """
    if "distill" in params.model_version:

        # Specify the student models
        if params.model_version == "cnn_distill":  # 5-layers Plain CNN
            print("Student model: {}".format(params.model_version))
            model = net.Net(params).cuda()

        elif params.model_version == "shufflenet_v2_distill":
            print("Student model: {}".format(params.model_version))
            model = shufflenet.shufflenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "mobilenet_v2_distill":
            print("Student model: {}".format(params.model_version))
            model = mobilenet.mobilenetv2(class_num=args.num_class).cuda()

        elif params.model_version == 'resnet18_distill':
            print("Student model: {}".format(params.model_version))
            model = resnet.ResNet18(num_classes=args.num_class).cuda()

        elif params.model_version == 'resnet50_distill':
            print("Student model: {}".format(params.model_version))
            model = resnet.ResNet50(num_classes=args.num_class).cuda()

        elif params.model_version == "alexnet_distill":
            print("Student model: {}".format(params.model_version))
            model = alexnet.alexnet(num_classes=args.num_class).cuda()

        elif params.model_version == "vgg19_distill":
            print("Student model: {}".format(params.model_version))
            model = models.vgg19_bn(num_classes=args.num_class).cuda()

        elif params.model_version == "googlenet_distill":
            print("Student model: {}".format(params.model_version))
            model = googlenet.GoogleNet(num_class=args.num_class).cuda()

        elif params.model_version == "resnext29_distill":
            print("Student model: {}".format(params.model_version))
            model = resnext.CifarResNeXt(cardinality=8,
                                         depth=29,
                                         num_classes=args.num_class).cuda()

        elif params.model_version == "densenet121_distill":
            print("Student model: {}".format(params.model_version))
            model = densenet.densenet121(num_class=args.num_class).cuda()

        # optimizer
        if params.model_version == "cnn_distill":
            optimizer = optim.Adam(model.parameters(),
                                   lr=params.learning_rate *
                                   (params.batch_size / 128))
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=params.learning_rate *
                                  (params.batch_size / 128),
                                  momentum=0.9,
                                  weight_decay=5e-4)

        iter_per_epoch = len(train_dl)
        warmup_scheduler = utils.WarmUpLR(
            optimizer, iter_per_epoch *
            args.warm)  # warmup the learning rate in the first epoch

        # specify loss function
        if args.self_training:
            print(
                '>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>self training>>>>>>>>>>>>>>>>>>>>>>>>>>>>>'
            )
            loss_fn_kd = loss_kd_self
        else:
            loss_fn_kd = loss_kd
        """ 
            Specify the pre-trained teacher models for knowledge distillation
            Checkpoints can be obtained by regular training or downloading our pretrained models
            For model which is pretrained in multi-GPU, use "nn.DaraParallel" to correctly load the model weights.
        """
        if params.teacher == "resnet18":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet18(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet18/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet18/0.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "alexnet":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = alexnet.alexnet(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_alexnet/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "googlenet":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = googlenet.GoogleNet(num_class=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_googlenet/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "vgg19":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = models.vgg19_bn(num_classes=args.num_class)
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_vgg19/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "resnet50":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet50(num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet50/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet50/50.pth.tar'

        elif params.teacher == "resnet101":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnet.ResNet101(num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnet101/best.pth.tar'
            teacher_model = teacher_model.cuda()

        elif params.teacher == "densenet121":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = densenet.densenet121(
                num_class=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_densenet121/best.pth.tar'
            # teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "resnext29":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = resnext.CifarResNeXt(
                cardinality=8, depth=29, num_classes=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnext29/best.pth.tar'
            if args.pt_teacher:  # poorly-trained teacher for Defective KD experiments
                teacher_checkpoint = 'experiments/pretrained_teacher_models/base_resnext29/50.pth.tar'
                teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "mobilenet_v2":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = mobilenet.mobilenetv2(
                class_num=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_mobilenet_v2/best.pth.tar'

        elif params.teacher == "shufflenet_v2":
            print("Teacher model: {}".format(params.teacher))
            teacher_model = shufflenet.shufflenetv2(
                class_num=args.num_class).cuda()
            teacher_checkpoint = 'experiments/pretrained_teacher_models/base_shufflenet_v2/best.pth.tar'

        utils.load_checkpoint(teacher_checkpoint, teacher_model)

        # Train the model with KD
        logging.info("Starting training for {} epoch(s)".format(
            params.num_epochs))
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl,
                              optimizer, loss_fn_kd, warmup_scheduler, params,
                              args, args.restore_file)

    # non-KD mode: regular training to obtain a baseline model
    else:
        print("Train base model")
        if params.model_version == "cnn":
            model = net.Net(params).cuda()

        elif params.model_version == "mobilenet_v2":
            print("model: {}".format(params.model_version))
            model = mobilenet.mobilenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "shufflenet_v2":
            print("model: {}".format(params.model_version))
            model = shufflenet.shufflenetv2(class_num=args.num_class).cuda()

        elif params.model_version == "alexnet":
            print("model: {}".format(params.model_version))
            model = alexnet.alexnet(num_classes=args.num_class).cuda()

        elif params.model_version == "vgg19":
            print("model: {}".format(params.model_version))
            model = models.vgg19_bn(num_classes=args.num_class).cuda()

        elif params.model_version == "googlenet":
            print("model: {}".format(params.model_version))
            model = googlenet.GoogleNet(num_class=args.num_class).cuda()

        elif params.model_version == "densenet121":
            print("model: {}".format(params.model_version))
            model = densenet.densenet121(num_class=args.num_class).cuda()

        elif params.model_version == "resnet18":
            model = resnet.ResNet18(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet50":
            model = resnet.ResNet50(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet101":
            model = resnet.ResNet101(num_classes=args.num_class).cuda()

        elif params.model_version == "resnet152":
            model = resnet.ResNet152(num_classes=args.num_class).cuda()

        elif params.model_version == "resnext29":
            model = resnext.CifarResNeXt(cardinality=8,
                                         depth=29,
                                         num_classes=args.num_class).cuda()
            # model = nn.DataParallel(model).cuda()

        if args.regularization:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Loss of Regularization>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = loss_kd_regularization
        elif args.label_smoothing:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Label Smoothing>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = loss_label_smoothing
        else:
            print(
                ">>>>>>>>>>>>>>>>>>>>>>>>Normal Training>>>>>>>>>>>>>>>>>>>>>>>>"
            )
            loss_fn = nn.CrossEntropyLoss()
            if args.double_training:  # double training, compare to self-KD
                print(
                    ">>>>>>>>>>>>>>>>>>>>>>>>Double Training>>>>>>>>>>>>>>>>>>>>>>>>"
                )
                checkpoint = 'experiments/pretrained_teacher_models/base_' + str(
                    params.model_version) + '/best.pth.tar'
                utils.load_checkpoint(checkpoint, model)

        if params.model_version == "cnn":
            optimizer = optim.Adam(model.parameters(),
                                   lr=params.learning_rate *
                                   (params.batch_size / 128))
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=params.learning_rate *
                                  (params.batch_size / 128),
                                  momentum=0.9,
                                  weight_decay=5e-4)

        iter_per_epoch = len(train_dl)
        warmup_scheduler = utils.WarmUpLR(optimizer,
                                          iter_per_epoch * args.warm)

        # Train the model
        logging.info("Starting training for {} epoch(s)".format(
            params.num_epochs))
        train_and_evaluate(model, train_dl, dev_dl, optimizer, loss_fn, params,
                           args.model_dir, warmup_scheduler, args,
                           args.restore_file)
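Both branches of Example #7 build a utils.WarmUpLR scheduler that the example does not define. A minimal sketch of a linear-warmup scheduler consistent with how it is constructed here (an assumption about the helper, not its confirmed source):

from torch.optim.lr_scheduler import _LRScheduler

class WarmUpLR(_LRScheduler):
    # Scale the learning rate linearly from ~0 up to the base LR over total_iters steps.
    def __init__(self, optimizer, total_iters, last_epoch=-1):
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        return [base_lr * self.last_epoch / (self.total_iters + 1e-8)
                for base_lr in self.base_lrs]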