Example #1
0
def get_model(arch, num_classes, channels=3):
    """
    Build a classification network by architecture name.

    Args:
        arch: string, Network architecture (case-insensitive). One of
            resnet18/34/50/101/152, mobilenet_v1, mobilenet_v2.
        num_classes: int, Number of classes
        channels: int, Number of input channels
    Returns:
        model, nn.Module, generated model
    Raises:
        NotImplementedError: if ``arch`` is not a supported architecture.
    """
    # Normalize once instead of calling .lower() in every branch.
    name = arch.lower()
    # NOTE: ResNet* constructors take (channels, num_classes) while
    # MobileNet* take (num_classes, channels) — argument order differs
    # by family, so the calls below are intentionally not uniform.
    if name == "resnet18":
        model = ResNet18(channels, num_classes)
    elif name == "resnet34":
        model = ResNet34(channels, num_classes)
    elif name == "resnet50":
        model = ResNet50(channels, num_classes)
    elif name == "resnet101":
        model = ResNet101(channels, num_classes)
    elif name == "resnet152":
        model = ResNet152(channels, num_classes)
    elif name == "mobilenet_v1":
        model = MobileNetV1(num_classes, channels)
    elif name == "mobilenet_v2":
        model = MobileNetV2(num_classes, channels)
    else:
        raise NotImplementedError(
            f"{arch} not implemented. "
            f"For supported architectures see documentation")
    return model
Example #2
0
File: run_covid.py  Project: tmquan/COVID
    def build_graph(self, image, label):
        """Build the training graph: forward pass, losses, and summaries.

        Args:
            image: input image tensor, expected in [0, 255] (rescaled below).
            label: ground-truth label tensor.
        Returns:
            The total training cost tensor (cross-entropy + weight decay).
        Raises:
            NotImplementedError: if ``self.args.name`` is not a known model.
        """
        # Rescale pixel values from [0, 255] to [-1, 1].
        image = image / 128.0 - 1.0

        # Forward pass through the backbone selected on the command line.
        # NOTE: some backbones also return a feature map ("recon") intended
        # for the reconstruction branch that is currently commented out.
        if self.args.name == 'VGG16':
            logit, recon = VGG16(image, classes=self.args.types)
        elif self.args.name == 'ShuffleNet':
            logit = ShuffleNet(image, classes=self.args.types)
        elif self.args.name == 'ResNet101':
            logit, recon = ResNet101(image,
                                     mode=self.args.mode,
                                     classes=self.args.types)
        elif self.args.name == 'DenseNet121':
            logit, recon = DenseNet121(image, classes=self.args.types)
        elif self.args.name == 'DenseNet169':
            logit, recon = DenseNet169(image, classes=self.args.types)
        elif self.args.name == 'DenseNet201':
            logit, recon = DenseNet201(image, classes=self.args.types)
        elif self.args.name == 'InceptionBN':
            logit = InceptionBN(image, classes=self.args.types)
        else:
            # BUGFIX: was `pass`, which left `logit` unbound and crashed
            # below with an opaque UnboundLocalError; fail fast instead.
            raise NotImplementedError(
                'Unsupported model name: {}'.format(self.args.name))

        estim = tf.sigmoid(logit, name='estim')
        loss_xent = class_balanced_sigmoid_cross_entropy(logit,
                                                         label,
                                                         name='loss_xent')
        # loss_dice = tf.identity(1.0 - dice_coe(estim, label, axis=[0,1], loss_type='jaccard'),
        #                          name='loss_dice')
        # # Reconstruction
        # with argscope([Conv2D, Conv2DTranspose], use_bias=False,
        #               kernel_initializer=tf.random_normal_initializer(stddev=0.02)), \
        #         argscope([Conv2D, Conv2DTranspose, InstanceNorm], data_format='channels_first'):
        #     recon = (LinearWrap(recon)
        #              .Conv2DTranspose('deconv0', 64 * 8, 3, strides=2)
        #              .Conv2DTranspose('deconv1', 64 * 8, 3, strides=2)
        #              .Conv2DTranspose('deconv2', 64 * 4, 3, strides=2)
        #              .Conv2DTranspose('deconv3', 64 * 2, 3, strides=2)
        #              .Conv2DTranspose('deconv4', 64 * 1, 3, strides=2)
        #              .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]], mode='SYMMETRIC')
        #              .Conv2D('recon', 1, 7, padding='VALID', activation=tf.tanh, use_bias=True)())
        #     recon = tf.transpose(recon, [0, 2, 3, 1])
        # loss_mae = tf.reduce_mean(tf.abs(recon-image), name='loss_mae')
        # Visualization: undo the [-1, 1] rescaling for display.
        visualize_tensors('image', [image],
                          scale_func=lambda x: x * 128.0 + 128.0,
                          max_outputs=max(64, self.args.batch))
        # Regularize the weights of the model (L2 with decaying coefficient).
        wd_w = tf.train.exponential_decay(2e-4, get_global_step_var(), 80000,
                                          0.7, True)
        wd_cost = tf.multiply(wd_w,
                              regularize_cost('.*/W', tf.nn.l2_loss),
                              name='wd_cost')

        add_param_summary(('.*/W', ['histogram']))  # monitor W
        cost = tf.add_n([loss_xent, wd_cost], name='cost')
        add_moving_summary(loss_xent)
        add_moving_summary(wd_cost)
        add_moving_summary(cost)
        return cost
Example #3
0
def initialize_model(model_name):
    '''
    Initialise a model with a custom head to predict both sequence length and digits

    Parameters
    ----------
    model_name : str
        Model Name can be either:
        ResNet18 / ResNet34 / ResNet50 / ResNet101 / ResNet152
        VGG*
        BaselineCNN
        BaselineCNN_dropout

    Returns
    -------
    model : object
        The initialized model

    Raises
    ------
    ValueError
        If model_name is not one of the supported architectures.
    '''

    if model_name[:3] == "VGG":
        model = VGG(model_name, num_classes=7)
        model.classifier = CustomHead(512)

    elif model_name[:6] == "ResNet":
        if model_name == "ResNet18":
            model = ResNet18(num_classes=7)
            model.linear = CustomHead(512)

        elif model_name == "ResNet34":
            # BUGFIX: this branch previously instantiated ResNet18.
            model = ResNet34(num_classes=7)
            model.linear = CustomHead(512)

        elif model_name == "ResNet50":
            model = ResNet50(num_classes=7)
            # Bottleneck-based ResNets expand the feature width by 4.
            model.linear = CustomHead(512 * 4)

        elif model_name == "ResNet101":
            model = ResNet101(num_classes=7)
            model.linear = CustomHead(512 * 4)

        elif model_name == "ResNet152":
            model = ResNet152(num_classes=7)
            model.linear = CustomHead(512 * 4)

        else:
            # Previously fell through and returned an unbound name.
            raise ValueError(f"Unsupported ResNet variant: {model_name}")

    elif model_name == "BaselineCNN":
        model = BaselineCNN(num_classes=7)
        model.fc2 = CustomHead(4096)

    elif model_name == "BaselineCNN_dropout":
        model = BaselineCNN_dropout(num_classes=7, p=0.5)
        model.fc2 = CustomHead(4096)

    else:
        # Previously fell through and crashed later with NameError on
        # `model`; fail fast with a clear message instead.
        raise ValueError(f"Unsupported model name: {model_name}")

    return model
Example #4
0
def get_model(model):
    """Return (network, checkpoint path) for a model name.

    Any name not listed in the table falls back to ResNet101, exactly as
    the original if/elif chain did.
    """
    saved_dir = '../saved'
    registry = {
        'LeNet-5': (LeNet, 'lenet.pth'),
        'VGG-16': (Vgg16_Net, 'vgg16.pth'),
        'ResNet18': (ResNet18, 'resnet18.pth'),
        'ResNet34': (ResNet34, 'resnet34.pth'),
        'ResNet50': (ResNet50, 'resnet50.pth'),
    }
    ctor, checkpoint = registry.get(model, (ResNet101, 'resnet101.pth'))
    net = ctor()
    return net, os.path.join(saved_dir, checkpoint)
Example #5
0
File: run_vinmec.py  Project: tmquan/COVID
    def build_graph(self, image, label):
        """Build the training graph: forward pass, losses, and summaries.

        Args:
            image: input image tensor, expected in [0, 255] (rescaled below).
            label: ground-truth label tensor.
        Returns:
            The total training cost tensor (cross-entropy + weight decay).
        Raises:
            NotImplementedError: if ``self.args.name`` is not a known model.
        """
        # Rescale pixel values from [0, 255] to [-1, 1].
        image = image / 128.0 - 1.0

        # Forward pass through the backbone selected on the command line.
        # NOTE: some backbones also return a feature map ("recon") that is
        # unused here — presumably for a reconstruction branch; confirm.
        if self.args.name == 'VGG16':
            logit, recon = VGG16(image, classes=self.args.types)
        elif self.args.name == 'ShuffleNet':
            logit = ShuffleNet(image, classes=self.args.types)
        elif self.args.name == 'ResNet101':
            logit, recon = ResNet101(image, mode=self.args.mode, classes=self.args.types)
        elif self.args.name == 'DenseNet121':
            logit, recon = DenseNet121(image, classes=self.args.types)
        elif self.args.name == 'DenseNet169':
            logit, recon = DenseNet169(image, classes=self.args.types)
        elif self.args.name == 'DenseNet201':
            logit, recon = DenseNet201(image, classes=self.args.types)
        elif self.args.name == 'InceptionBN':
            logit = InceptionBN(image, classes=self.args.types)
        else:
            # BUGFIX: was `pass`, which left `logit` unbound and crashed
            # below with an opaque UnboundLocalError; fail fast instead.
            raise NotImplementedError(
                'Unsupported model name: {}'.format(self.args.name))

        estim = tf.sigmoid(logit, name='estim')
        loss_xent = class_balanced_sigmoid_cross_entropy(logit, label, name='loss_xent')

        # Visualization: undo the [-1, 1] rescaling for display.
        visualize_tensors('image', [image], scale_func=lambda x: x * 128.0 + 128.0,
                          max_outputs=max(64, self.args.batch))
        # Regularize the weights of the model (L2 with decaying coefficient).
        wd_w = tf.train.exponential_decay(2e-4, get_global_step_var(),
                                          80000, 0.7, True)
        wd_cost = tf.multiply(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')

        add_param_summary(('.*/W', ['histogram']))   # monitor W
        cost = tf.add_n([loss_xent, wd_cost], name='cost')
        add_moving_summary(loss_xent)
        add_moving_summary(wd_cost)
        add_moving_summary(cost)
        return cost
Example #6
0
 def select(self, model, args):
     """
     Selector utility to create models from model directory
     :param model: which model to select. Currently choices are: (cnn | resnet | preact_resnet | densenet | wresnet)
     :param args: parsed options (expects .dataset, .resdepth, .widen_factor, and cnn hyper-parameters)
     :return: neural network to be trained
     """
     if model == 'cnn':
         net = SimpleModel(in_shape=self.in_shape,
                           activation=args.activation,
                           num_classes=self.num_classes,
                           filters=args.filters,
                           strides=args.strides,
                           kernel_sizes=args.kernel_sizes,
                           linear_widths=args.linear_widths,
                           use_batch_norm=args.use_batch_norm)
     else:
         assert (args.dataset != 'MNIST' and args.dataset != 'Fashion-MNIST'), \
             "Cannot use resnet or densenet for mnist style data"
         if model == 'resnet':
             assert args.resdepth in [18, 34, 50, 101, 152], \
                 "Non-standard and unsupported resnet depth ({})".format(args.resdepth)
             if args.resdepth == 18:
                 net = ResNet18(self.num_classes)
             elif args.resdepth == 34:
                 net = ResNet34(self.num_classes)
             elif args.resdepth == 50:
                 net = ResNet50(self.num_classes)
             elif args.resdepth == 101:
                 net = ResNet101(self.num_classes)
             else:
                 # BUGFIX: previously ResNet152() was built without
                 # num_classes, unlike every other depth.
                 net = ResNet152(self.num_classes)
         elif model == 'densenet':
             assert args.resdepth in [121, 161, 169, 201], \
                 "Non-standard and unsupported densenet depth ({})".format(args.resdepth)
             if args.resdepth == 121:
                 net = DenseNet121(
                     growth_rate=12, num_classes=self.num_classes
                 )  # NB NOTE: growth rate controls cifar implementation
             elif args.resdepth == 161:
                 net = DenseNet161(growth_rate=12,
                                   num_classes=self.num_classes)
             elif args.resdepth == 169:
                 net = DenseNet169(growth_rate=12,
                                   num_classes=self.num_classes)
             else:
                 net = DenseNet201(growth_rate=12,
                                   num_classes=self.num_classes)
         elif model == 'preact_resnet':
             assert args.resdepth in [18, 34, 50, 101, 152], \
                 "Non-standard and unsupported preact resnet depth ({})".format(args.resdepth)
             if args.resdepth == 18:
                 net = PreActResNet18(self.num_classes)
             elif args.resdepth == 34:
                 net = PreActResNet34(self.num_classes)
             elif args.resdepth == 50:
                 net = PreActResNet50(self.num_classes)
             elif args.resdepth == 101:
                 net = PreActResNet101(self.num_classes)
             else:
                 # BUGFIX: previously PreActResNet152() was built without
                 # num_classes, unlike every other depth.
                 net = PreActResNet152(self.num_classes)
         elif model == 'wresnet':
             assert ((args.resdepth - 4) % 6 == 0), \
                 "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(args.resdepth)
             net = WideResNet(depth=args.resdepth,
                              num_classes=self.num_classes,
                              widen_factor=args.widen_factor)
         else:
             raise NotImplementedError(
                 'Model {} not supported'.format(model))
     return net
Example #7
0
    def select(self, model, path_fc=False, upsample='pixel'):
        """Create a network from its family name and ``self.resdepth``.

        :param model: one of (cnn | resnet | densenet | preact_resnet | wresnet)
        :param path_fc: forwarded to PreActResNet10 only
        :param upsample: forwarded to PreActResNet10 only
        :return: the constructed network
        :raises NotImplementedError: for an unrecognized ``model`` name
        """
        if model == 'cnn':
            net = SimpleModel(
                in_shape=self.in_shape,
                activation=self.activation,
                num_classes=self.num_classes,
                filters=self.filters,
            )
        else:
            assert (self.dataset != 'MNIST' and self.dataset != 'Fashion-MNIST'
                    ), "Cannot use resnet or densenet for mnist style data"
            if model == 'resnet':
                assert self.resdepth in [
                    18, 34, 50, 101, 152
                ], "Non-standard and unsupported resnet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 18:
                    net = ResNet18()
                elif self.resdepth == 34:
                    net = ResNet34()
                elif self.resdepth == 50:
                    net = ResNet50()
                elif self.resdepth == 101:
                    net = ResNet101()
                else:
                    net = ResNet152()
            elif model == 'densenet':
                assert self.resdepth in [
                    121, 161, 169, 201
                ], "Non-standard and unsupported densenet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 121:
                    net = DenseNet121()
                elif self.resdepth == 161:
                    net = DenseNet161()
                elif self.resdepth == 169:
                    net = DenseNet169()
                else:
                    net = DenseNet201()
            elif model == 'preact_resnet':
                assert self.resdepth in [
                    10, 18, 34, 50, 101, 152
                ], "Non-standard and unsupported preact resnet depth ({})".format(
                    self.resdepth)
                if self.resdepth == 10:
                    # Only the depth-10 variant takes the path_fc/upsample
                    # options; the others use their defaults.
                    net = PreActResNet10(path_fc=path_fc,
                                         num_classes=self.num_classes,
                                         upsample=upsample)
                elif self.resdepth == 18:
                    net = PreActResNet18()
                elif self.resdepth == 34:
                    net = PreActResNet34()
                elif self.resdepth == 50:
                    net = PreActResNet50()
                elif self.resdepth == 101:
                    net = PreActResNet101()
                else:
                    net = PreActResNet152()
            elif model == 'wresnet':
                assert (
                    (self.resdepth - 4) % 6 == 0
                ), "Wideresnet depth of {} not supported, must fulfill: (depth - 4) % 6 = 0".format(
                    self.resdepth)
                net = WideResNet(depth=self.resdepth,
                                 num_classes=self.num_classes,
                                 widen_factor=self.widen_factor)
            else:
                # BUGFIX: an unknown model name previously fell through and
                # crashed at `return net` with UnboundLocalError.
                raise NotImplementedError(
                    'Model {} not supported'.format(model))

        return net
Example #8
0
File: __init__.py  Project: llucid-97/AdaS
def get_network(name: str, num_classes: int) -> "nn.Module | None":
    """Instantiate the network registered under ``name``.

    Args:
        name: architecture identifier, e.g. 'ResNet18' or 'VGG16_BN'.
        num_classes: number of output classes passed to every constructor.
    Returns:
        The constructed network, or None when ``name`` is unknown.
        (The original annotation said ``-> None``, which was inaccurate.)
    """
    # Constructors are wrapped in lambdas so only the selected class name is
    # resolved and instantiated — preserving the lazy evaluation of the
    # original 40-way chained conditional expression.
    builders = {
        'AlexNet': lambda: AlexNet(num_classes=num_classes),
        'DenseNet201': lambda: DenseNet201(num_classes=num_classes),
        'DenseNet169': lambda: DenseNet169(num_classes=num_classes),
        'DenseNet161': lambda: DenseNet161(num_classes=num_classes),
        'DenseNet121': lambda: DenseNet121(num_classes=num_classes),
        'DenseNet121CIFAR': lambda: DenseNet121CIFAR(num_classes=num_classes),
        'GoogLeNet': lambda: GoogLeNet(num_classes=num_classes),
        'InceptionV3': lambda: InceptionV3(num_classes=num_classes),
        'MNASNet_0_5': lambda: MNASNet_0_5(num_classes=num_classes),
        'MNASNet_0_75': lambda: MNASNet_0_75(num_classes=num_classes),
        'MNASNet_1': lambda: MNASNet_1(num_classes=num_classes),
        'MNASNet_1_3': lambda: MNASNet_1_3(num_classes=num_classes),
        'MobileNetV2': lambda: MobileNetV2(num_classes=num_classes),
        'ResNet18': lambda: ResNet18(num_classes=num_classes),
        'ResNet34': lambda: ResNet34(num_classes=num_classes),
        'ResNet34CIFAR': lambda: ResNet34CIFAR(num_classes=num_classes),
        'ResNet50CIFAR': lambda: ResNet50CIFAR(num_classes=num_classes),
        'ResNet101CIFAR': lambda: ResNet101CIFAR(num_classes=num_classes),
        'ResNet18CIFAR': lambda: ResNet18CIFAR(num_classes=num_classes),
        'ResNet50': lambda: ResNet50(num_classes=num_classes),
        'ResNet101': lambda: ResNet101(num_classes=num_classes),
        'ResNet152': lambda: ResNet152(num_classes=num_classes),
        # NOTE: lookup keys keep the original (inconsistent) 'ResNext' casing
        # so existing callers keep working.
        'ResNext50': lambda: ResNeXt50(num_classes=num_classes),
        'ResNeXtCIFAR': lambda: ResNeXtCIFAR(num_classes=num_classes),
        'ResNext101': lambda: ResNeXt101(num_classes=num_classes),
        'WideResNet50': lambda: WideResNet50(num_classes=num_classes),
        'WideResNet101': lambda: WideResNet101(num_classes=num_classes),
        'ShuffleNetV2_0_5': lambda: ShuffleNetV2_0_5(num_classes=num_classes),
        'ShuffleNetV2_1': lambda: ShuffleNetV2_1(num_classes=num_classes),
        'ShuffleNetV2_1_5': lambda: ShuffleNetV2_1_5(num_classes=num_classes),
        'ShuffleNetV2_2': lambda: ShuffleNetV2_2(num_classes=num_classes),
        'SqueezeNet_1': lambda: SqueezeNet_1(num_classes=num_classes),
        'SqueezeNet_1_1': lambda: SqueezeNet_1_1(num_classes=num_classes),
        'VGG11': lambda: VGG11(num_classes=num_classes),
        'VGG11_BN': lambda: VGG11_BN(num_classes=num_classes),
        'VGG13': lambda: VGG13(num_classes=num_classes),
        'VGG13_BN': lambda: VGG13_BN(num_classes=num_classes),
        'VGG16': lambda: VGG16(num_classes=num_classes),
        'VGG16_BN': lambda: VGG16_BN(num_classes=num_classes),
        'VGG19': lambda: VGG19(num_classes=num_classes),
        'VGG19_BN': lambda: VGG19_BN(num_classes=num_classes),
        'VGG16CIFAR': lambda: VGGCIFAR('VGG16', num_classes=num_classes),
        'EfficientNetB4': lambda: EfficientNetB4(num_classes=num_classes),
        'EfficientNetB0CIFAR':
            lambda: EfficientNetB0CIFAR(num_classes=num_classes),
    }
    build = builders.get(name)
    return build() if build is not None else None
Example #9
0
File: main.py  Project: yukuotc/PaddleHub
def train():
    """Fine-tune ResNet101 on the target dataset with DELTA regularization.

    Reads configuration from the module-level ``args`` (dataset, pretrained
    model path, batch size, epochs, decay rate, output directory).  A frozen
    copy of the ImageNet-pretrained teacher network regularizes the student's
    last pre-pooling feature map.
    """
    dataset = args.dataset
    image_shape = [3, 224, 224]
    pretrained_model = args.pretrained_model

    class_map_path = f'{global_data_path}/{dataset}/readable_label.txt'

    # Optional readable-label map (numeric id -> human-readable name).
    # NOTE(review): label_dict is built but never used below — kept for parity.
    if os.path.exists(class_map_path):
        logger.info(
            "The map of readable label and numerical label has been found!")
        with open(class_map_path) as f:
            label_dict = {}
            strinfo = re.compile(r"\d+ ")
            for item in f.readlines():
                key = int(item.split(" ")[0])
                value = [
                    strinfo.sub("", l).replace("\n", "")
                    for l in item.split(", ")
                ]
                label_dict[key] = value[0]

    assert os.path.isdir(
        pretrained_model), "please load right pretrained model path for infer"

    # data reader
    batch_size = args.batch_size
    reader_config = ReaderConfig(f'{global_data_path}/{dataset}', is_test=False)
    reader = reader_config.get_reader()
    # NOTE(review): buf_size == batch_size gives a very small shuffle window;
    # consider a larger buffer if ordering bias is a concern.
    train_reader = paddle.batch(
        paddle.reader.shuffle(reader, buf_size=batch_size),
        batch_size,
        drop_last=True)

    # model ops: student network trained on the target classes
    image = fluid.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    model = ResNet101(is_test=False)
    features, logits = model.net(
        input=image, class_dim=reader_config.num_classes)
    out = fluid.layers.softmax(logits)

    # loss, metric
    cost = fluid.layers.mean(fluid.layers.cross_entropy(out, label))
    accuracy = fluid.layers.accuracy(input=out, label=label)

    # delta regularization
    # teacher model pre-trained on Imagenet, 1000 classes.
    global_name = 't_'
    t_model = ResNet101(is_test=True, global_name=global_name)
    t_features, _ = t_model.net(input=image, class_dim=1000)
    for f in t_features.keys():
        t_features[f].stop_gradient = True

    # delta loss. hard code for the layer name, which is just before global pooling.
    delta_loss = fluid.layers.square(t_features['t_res5c.add.output.5.tmp_0'] -
                                     features['res5c.add.output.5.tmp_0'])
    delta_loss = fluid.layers.reduce_mean(delta_loss)

    # Collect only student parameters for optimization; teacher ('t_'-prefixed)
    # parameters stay frozen.
    params = fluid.default_main_program().global_block().all_parameters()
    parameters = []
    for param in params:
        if param.trainable:
            if global_name in param.name:
                print('\tfixing', param.name)
            else:
                print('\ttraining', param.name)
                parameters.append(param.name)

    # optimizer, with piecewise_decay learning rate (drops at 2/3 of training).
    total_steps = len(reader_config.image_paths) * args.num_epoch // batch_size
    boundaries = [int(total_steps * 2 / 3)]
    print('\ttotal learning steps:', total_steps)
    print('\tlr decays at:', boundaries)
    values = [0.01, 0.001]
    optimizer = fluid.optimizer.Momentum(
        learning_rate=fluid.layers.piecewise_decay(
            boundaries=boundaries, values=values),
        momentum=0.9,
        parameter_list=parameters,
        regularization=fluid.regularizer.L2Decay(args.wd_rate))
    cur_lr = optimizer._global_learning_rate()

    optimizer.minimize(
        cost + args.delta_reg * delta_loss, parameter_list=parameters)

    # data reader
    feed_order = ['image', 'label']

    # executor (session)
    place = fluid.CUDAPlace(
        args.use_cuda) if args.use_cuda >= 0 else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # running
    main_program = fluid.default_main_program()
    start_program = fluid.default_startup_program()

    feed_var_list_loop = [
        main_program.global_block().var(var_name) for var_name in feed_order
    ]
    feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
    exe.run(start_program)

    # Split checkpoint variables between student and teacher; 'fc' layers are
    # skipped because the class count differs from the checkpoint's.
    loading_parameters = {}
    t_loading_parameters = {}
    for p in main_program.all_parameters():
        if 'fc' not in p.name:
            if global_name in p.name:
                new_name = os.path.join(pretrained_model,
                                        p.name.split(global_name)[-1])
                t_loading_parameters[new_name] = p
                print(new_name, p.name)
            else:
                name = os.path.join(pretrained_model, p.name)
                loading_parameters[name] = p
                print(name, p.name)
        else:
            print(f'not loading {p.name}')

    load_vars_by_dict(exe, loading_parameters, main_program=main_program)
    load_vars_by_dict(exe, t_loading_parameters, main_program=main_program)

    step = 0

    # test_data = reader_creator_all_in_memory('./datasets/PetImages', is_test=True)
    for e_id in range(args.num_epoch):
        avg_delta_loss = AverageMeter()
        avg_loss = AverageMeter()
        avg_accuracy = AverageMeter()
        batch_time = AverageMeter()
        end = time.time()

        for step_id, data_train in enumerate(train_reader()):
            wrapped_results = exe.run(
                main_program,
                feed=feeder.feed(data_train),
                fetch_list=[cost, accuracy, delta_loss, cur_lr])
            # print(avg_loss_value[2])
            batch_time.update(time.time() - end)
            end = time.time()

            avg_loss.update(wrapped_results[0][0], len(data_train))
            avg_accuracy.update(wrapped_results[1][0], len(data_train))
            avg_delta_loss.update(wrapped_results[2][0], len(data_train))
            if step % 100 == 0:
                print(
                    f"\tEpoch {e_id}, Global_Step {step}, Batch_Time {batch_time.avg: .2f},"
                    f" LR {wrapped_results[3][0]}, "
                    f"Loss {avg_loss.avg: .4f}, Acc {avg_accuracy.avg: .4f}, Delta_Loss {avg_delta_loss.avg: .4f}"
                )
            step += 1

        if args.outdir is not None:
            try:
                os.makedirs(args.outdir, exist_ok=True)
                fluid.io.save_params(
                    executor=exe, dirname=args.outdir + '/' + get_model_id())
            except Exception:
                # BUGFIX: was a bare `except:`, which also swallowed
                # KeyboardInterrupt/SystemExit during checkpointing.
                print('\t Not saving trained parameters.')

        if e_id == args.num_epoch - 1:
            print("kpis\ttrain_cost\t%f" % avg_loss.avg)
            print("kpis\ttrain_acc\t%f" % avg_accuracy.avg)
Example #10
0
File: main.py  Project: yukuotc/PaddleHub
def test():
    """Evaluate the fine-tuned ResNet101 checkpoint on the test split.

    Reads configuration from the module-level ``args`` (dataset, batch size,
    output directory); loads parameters saved by ``train()`` from
    ``args.outdir`` and prints average loss/accuracy.
    """
    image_shape = [3, 224, 224]
    # Checkpoint directory written by train() for this run.
    pretrained_model = args.outdir + '/' + get_model_id()

    # data reader
    batch_size = args.batch_size
    reader_config = ReaderConfig(
        f'{global_data_path}/{args.dataset}', is_test=True)
    reader = reader_config.get_reader()
    test_reader = paddle.batch(reader, batch_size)

    # model ops (inference mode; intermediate features are discarded)
    image = fluid.data(
        name='image', shape=[None] + image_shape, dtype='float32')
    label = fluid.data(name='label', shape=[None, 1], dtype='int64')
    model = ResNet101(is_test=True)
    _, logits = model.net(input=image, class_dim=reader_config.num_classes)
    out = fluid.layers.softmax(logits)

    # loss, metric
    cost = fluid.layers.mean(fluid.layers.cross_entropy(out, label))
    accuracy = fluid.layers.accuracy(input=out, label=label)

    # data reader
    feed_order = ['image', 'label']

    # executor (session); args.use_cuda holds the GPU id, negative = CPU
    place = fluid.CUDAPlace(
        args.use_cuda) if args.use_cuda >= 0 else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # running
    main_program = fluid.default_main_program()
    start_program = fluid.default_startup_program()

    feed_var_list_loop = [
        main_program.global_block().var(var_name) for var_name in feed_order
    ]
    feeder = fluid.DataFeeder(feed_list=feed_var_list_loop, place=place)
    exe.run(start_program)

    # Restore the fine-tuned parameters before evaluation.
    fluid.io.load_params(exe, pretrained_model)

    step = 0
    avg_loss = AverageMeter()
    avg_accuracy = AverageMeter()

    # Running averages over the whole test set; fetch returns [loss, acc].
    for step_id, data_train in enumerate(test_reader()):
        avg_loss_value = exe.run(
            main_program,
            feed=feeder.feed(data_train),
            fetch_list=[cost, accuracy])
        avg_loss.update(avg_loss_value[0], len(data_train))
        avg_accuracy.update(avg_loss_value[1], len(data_train))
        if step_id % 10 == 0:
            print("\nBatch %d, Loss %f, Acc %f" % (step_id, avg_loss.avg,
                                                   avg_accuracy.avg))
        step += 1

    print("test counts:", avg_loss.count)
    print("test_cost\t%f" % avg_loss.avg)
    print("test_acc\t%f" % avg_accuracy.avg)
Example #11
0
def main(args):
    """Entry point: build the dataset and model from CLI flags, then run
    training, evaluation, model display, and/or prediction as requested.

    Raises:
        ValueError: if none of the model-selection flags is set.
    """

    check_path(args)

    # All 10 CIFAR-10 class names.
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    # Dataset construction.
    data_builder = DataBuilder(args)
    dataSet = DataSet(data_builder.train_builder(),
                      data_builder.test_builder(), classes)

    # Model selection — exactly one of the flags below should be set.
    if args.lenet:
        net = LeNet()
        model_name = args.name_le
    elif args.vgg:
        net = Vgg16_Net()
        model_name = args.name_vgg
    elif args.resnet18:
        net = ResNet18()
        model_name = args.name_res18
    elif args.resnet34:
        net = ResNet34()
        model_name = args.name_res34
    elif args.resnet50:
        net = ResNet50()
        model_name = args.name_res50
    elif args.resnet101:
        net = ResNet101()
        model_name = args.name_res101
    elif args.resnet152:
        net = ResNet152()
        model_name = args.name_res152
    else:
        # BUGFIX: previously fell through with `net` unbound and crashed
        # later with a confusing NameError; fail fast instead.
        raise ValueError("No model selected: pass one of --lenet/--vgg/"
                         "--resnet18/34/50/101/152.")

    # Cross-entropy loss.
    criterion = nn.CrossEntropyLoss()

    # SGD optimizer.
    optimizer = optim.SGD(net.parameters(),
                          lr=args.learning_rate,
                          momentum=args.sgd_momentum,
                          weight_decay=args.weight_decay)

    # Cosine-annealing learning-rate schedule.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=150)

    # Path where the model parameters are saved.
    model_path = os.path.join(args.model_path, model_name)

    # Training.
    if args.do_train:
        print("Training...")

        trainer = Trainer(net, criterion, optimizer, scheduler,
                          dataSet.train_loader, dataSet.test_loader,
                          model_path, args)

        trainer.train(epochs=args.epoch)
        # t.save(net.state_dict(), model_path)

    # Evaluation: if --do_train was also given, test the freshly trained
    # model; otherwise load the previously saved checkpoint.
    if args.do_eval:
        if not args.do_train and not os.path.exists(model_path):
            print(
                "Sorry, there's no saved model yet, you need to train first.")
            return
        # --do_eval
        if not args.do_train:
            checkpoint = t.load(model_path)
            net.load_state_dict(checkpoint['net'])
            accuracy = checkpoint['acc']
            epoch = checkpoint['epoch']
            print("Using saved model, accuracy : %f  epoch: %d" %
                  (accuracy, epoch))
        tester = Tester(dataSet.test_loader, net, args)
        tester.test()

    if args.show_model:
        if not os.path.exists(model_path):
            print(
                "Sorry, there's no saved model yet, you need to train first.")
            return
        show_model(args)

    # Prediction over every image in the local 'test' directory.
    if args.do_predict:
        device = t.device("cuda" if t.cuda.is_available() else "cpu")
        checkpoint = t.load(model_path, map_location=device)
        net.load_state_dict(checkpoint['net'])
        predictor = Predictor(net, classes)
        img_path = 'test'
        img_name = [os.path.join(img_path, x) for x in os.listdir(img_path)]
        for img in img_name:
            predictor.predict(img)