Example #1
def get_fe(fe, input_image):
    if fe == 'effnetb0':
        return EfficientNetB0(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'effnetb1':
        return EfficientNetB1(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'effnetb2':
        return EfficientNetB2(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'effnetb3':
        return EfficientNetB3(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'effnetb4':
        return EfficientNetB4(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'effnetb5':
        return EfficientNetB5(input_tensor=input_image, include_top=False), [
            'swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1'
        ]
    elif fe == 'd53':
        return d53(input_image)
    elif fe == 'mnetv2':
        mnet = MobileNetV2(input_tensor=input_image, weights='imagenet')
        return mnet, [
            'out_relu', 'block_13_expand_relu', 'block_6_expand_relu'
        ]
    elif fe == 'mnet':
        mnet = MobileNet(input_tensor=input_image, weights='imagenet')
        return mnet, ['conv_pw_13_relu', 'conv_pw_11_relu', 'conv_pw_5_relu']
    elif fe == 'r50':
        r50 = ResNet50(input_tensor=input_image, weights='imagenet')
        return r50, ['activation_49', 'activation_40', 'activation_22']
    raise ValueError('Please specify a valid feature extractor (fe)')
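
A minimal usage sketch for get_fe() (not part of the original listing); it assumes the backbone constructors referenced above are importable in the same module and that the caller supplies the Keras input tensor. The 224x224 resolution and the 'mnetv2' choice are illustrative assumptions only.

# Hypothetical usage of get_fe(); input size and backbone name are assumptions.
from keras.layers import Input

input_image = Input(shape=(224, 224, 3))
backbone, feature_layer_names = get_fe('mnetv2', input_image)
# Collect the three returned feature maps (coarse to fine), e.g. for a
# detection-style head built on top of the backbone.
features = [backbone.get_layer(name).output for name in feature_layer_names]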
Example #2
def train(dataset='Fish4Knowledge',
          cnn='ResNet50',
          batch_size=32,
          epochs=50,
          image_size=32,
          eps=0.01):
    """
    Train one checkpoint with data augmentation: random padding+cropping and horizontal flip
    :param args: 
    :return:
    """
    print(
        'Dataset: %s, CNN: %s, loss: adv, epsilon: %.3f, batch: %s, epochs: %s'
        % (dataset, cnn, eps, batch_size, epochs))

    IMAGE_SIZE = image_size
    INPUT_SHAPE = (image_size, image_size, 3)

    # find image folder: images are distributed in class subfolders
    if dataset == 'Fish4Knowledge':
        # image_path = '/home/xingjun/datasets/Fish4Knowledge/fish_image'
        image_path = '/data/cephfs/punim0619/Fish4Knowledge/fish_image'
        images, labels, images_val, labels_val = get_Fish4Knowledge(
            image_path, train_ratio=0.8)
    elif dataset == 'QUTFish':
        # image_path = '/home/xingjun/datasets/QUT_fish_data'
        image_path = '/data/cephfs/punim0619/QUT_fish_data'
        images, labels, images_val, labels_val = get_QUTFish(image_path,
                                                             train_ratio=0.8)
    elif dataset == 'WildFish':
        # image_pathes = ['/home/xingjun/datasets/WildFish/WildFish_part1',
        #             '/home/xingjun/datasets/WildFish/WildFish_part2',
        #             '/home/xingjun/datasets/WildFish/WildFish_part3',
        #             '/home/xingjun/datasets/WildFish/WildFish_part4']
        image_pathes = [
            '/data/cephfs/punim0619/WildFish/WildFish_part1',
            '/data/cephfs/punim0619/WildFish/WildFish_part2',
            '/data/cephfs/punim0619/WildFish/WildFish_part3',
            '/data/cephfs/punim0619/WildFish/WildFish_part4'
        ]
        images, labels, images_val, labels_val = get_WildFish(image_pathes,
                                                              train_ratio=0.8)

    # images, labels, images_val, labels_val = get_imagenet_googlesearch_data(image_path, num_class=NUM_CLASS)
    num_classes = len(np.unique(labels))
    num_images = len(images)
    num_images_val = len(images_val)
    print('Train: classes: %s, images: %s, val images: %s' %
          (num_classes, num_images, num_images_val))

    global current_index
    current_index = 0

    # dynamically load a batch of training data
    def get_batch():
        index = 0  # fill every slot of the batch, starting at 0


        global current_index

        B = np.zeros(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        L = np.zeros(shape=(batch_size))
        while index < batch_size:
            if current_index >= num_images:
                current_index = 0  # wrap around if the image list is exhausted
            try:
                img = load_img(images[current_index],
                               target_size=(IMAGE_SIZE, IMAGE_SIZE))
                img = img_to_array(img)
                img /= 255.
                # if cnn == 'ResNet50': # imagenet pretrained
                #     mean = np.array([0.485, 0.456, 0.406])
                #     std = np.array([0.229, 0.224, 0.225])
                #     img = (img - mean)/std
                ## data augmentation
                # random width and height shift
                img = random_shift(img, 0.2, 0.2)
                # random rotation
                img = random_rotation(img, 10)
                # random horizontal flip
                flip_horizontal = (np.random.random() < 0.5)
                if flip_horizontal:
                    img = flip_axis(img, axis=1)
                # # random vertical flip
                # flip_vertical = (np.random.random() < 0.5)
                # if flip_vertical:
                #     img = flip_axis(img, axis=0)
                # #cutout
                # eraser = get_random_eraser(v_l=0, v_h=1, pixel_level=False)
                # img = eraser(img)

                B[index] = img
                L[index] = labels[current_index]
                index = index + 1
                current_index = current_index + 1
            except Exception:
                traceback.print_exc()
                # print("Ignore image {}".format(images[current_index]))
                current_index = current_index + 1
        # B = np.rollaxis(B, 3, 1)
        return B, np_utils.to_categorical(L, num_classes)

    global val_current_index
    val_current_index = 0

    # dynamically load a batch of validation data
    def get_val_batch():
        index = 0  # fill every slot of the batch, starting at 0
        B = np.zeros(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        L = np.zeros(shape=(batch_size))

        global val_current_index

        while index < batch_size:
            if val_current_index >= num_images_val:
                val_current_index = 0  # wrap around if the list is exhausted
            try:
                img = load_img(images_val[val_current_index],
                               target_size=(IMAGE_SIZE, IMAGE_SIZE))
                img = img_to_array(img)
                img /= 255.
                # if cnn == 'ResNet50': # imagenet pretrained
                #     mean = np.array([0.485, 0.456, 0.406])
                #     std = np.array([0.229, 0.224, 0.225])
                #     img = (img - mean)/std
                B[index] = img
                L[index] = labels_val[val_current_index]
                index = index + 1
                val_current_index = val_current_index + 1
            except Exception:
                traceback.print_exc()
                # print("Ignore image {}".format(images[val_current_index]))
                val_current_index = val_current_index + 1
        # B = np.rollaxis(B, 3, 1)
        return B, np_utils.to_categorical(L, num_classes)

    # build the backbone model
    if cnn == 'ResNet18':
        base_model = ResNet18(input_shape=INPUT_SHAPE,
                              classes=num_classes,
                              include_top=False)
    elif cnn == 'ResNet34':
        base_model = ResNet34(input_shape=INPUT_SHAPE,
                              classes=num_classes,
                              include_top=False)
    elif cnn == 'ResNet50':
        base_model = ResNet50(include_top=False,
                              weights='imagenet',
                              input_shape=INPUT_SHAPE)
    elif cnn == 'EfficientNetB1':
        base_model = EfficientNetB1(input_shape=INPUT_SHAPE,
                                    classes=num_classes,
                                    include_top=False,
                                    backend=keras.backend,
                                    layers=keras.layers,
                                    models=keras.models,
                                    utils=keras.utils)
    elif cnn == 'EfficientNetB2':
        base_model = EfficientNetB2(input_shape=INPUT_SHAPE,
                                    classes=num_classes,
                                    include_top=False,
                                    backend=keras.backend,
                                    layers=keras.layers,
                                    models=keras.models,
                                    utils=keras.utils)
    elif cnn == 'EfficientNetB3':
        base_model = EfficientNetB3(input_shape=INPUT_SHAPE,
                                    classes=num_classes,
                                    include_top=False,
                                    backend=keras.backend,
                                    layers=keras.layers,
                                    models=keras.models,
                                    utils=keras.utils)
    elif cnn == 'EfficientNetB4':
        base_model = EfficientNetB4(input_shape=INPUT_SHAPE,
                                    classes=num_classes,
                                    include_top=False,
                                    backend=keras.backend,
                                    layers=keras.layers,
                                    models=keras.models,
                                    utils=keras.utils)
    else:
        warnings.warn("Error: unrecognized model architecture!")
        return

    x = base_model.output
    x = Flatten()(x)
    x = Dense(num_classes, name='dense')(x)
    output = Activation('softmax')(x)
    model = Model(inputs=base_model.input, outputs=output, name=cnn)
    # model.summary()

    loss = cross_entropy

    base_lr = 1e-2
    sgd = SGD(lr=base_lr, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss=loss, optimizer=sgd, metrics=['accuracy'])

    # AdaFGSM attack for AdvFish training
    attack = AdaFGSM(model,
                     epsilon=float(eps),
                     random_start=True,
                     loss_func='xent',
                     clip_min=0.,
                     clip_max=1.)
    # PGD attack for AdvFish training; it reduces to FGSM when nb_iter=1
    # attack = LinfPGDAttack(model,
    #                      epsilon=float(eps),
    #                      eps_iter=float(eps),
    #                      nb_iter=1,
    #                      random_start=True,
    #                      loss_func='xent',
    #                      clip_min=0.,
    #                      clip_max=1.)

    # save weights during and after training
    # create the folders if they do not exist
    if not os.path.exists('models/'):
        os.makedirs('models/')
    log_path = 'log/%s' % dataset
    if not os.path.exists(log_path):
        os.makedirs(log_path)

    ## scan the weights folder, then load the latest weight file to continue training
    model_prefix = '%s_%s_%s_%.4f_' % (dataset, cnn, 'adv', eps)
    w_files = os.listdir('models/')
    existing_ep = 0
    # for fl in w_files:
    #     if model_prefix in fl:
    #         ep = re.search(model_prefix+"(.+?).h5", fl).group(1)
    #         if int(ep) > existing_ep:
    #             existing_ep = int(ep)
    #
    # if existing_ep > 0:
    #     weight_file = 'models/' + model_prefix + str(existing_ep) + ".h5"
    #     print("load previous model weights from: ", weight_file)
    #     model.load_weights(weight_file)
    #
    #     log = np.load(os.path.join(log_path, 'train_log_%s_%s_%.3f.npy' % (cnn, 'adv', eps)))
    #
    #     train_loss_log = log[0, :existing_ep+1].tolist()
    #     train_acc_log = log[1, :existing_ep+1].tolist()
    #     val_loss_log = log[2, :existing_ep+1].tolist()
    #     val_acc_log = log[3, :existing_ep+1].tolist()
    # else:
    train_loss_log = []
    train_acc_log = []
    val_loss_log = []
    val_acc_log = []

    # training loop (batches are loaded dynamically)
    for ep in range(epochs - existing_ep):
        # cosine learning rate annealing
        eta_min = 1e-5
        eta_max = base_lr
        lr = eta_min + (eta_max -
                        eta_min) * (1 + math.cos(math.pi * ep / epochs)) / 2
        K.set_value(model.optimizer.lr, lr)
        # # step-wise learning rate annealing
        # if ep in [int(epochs*0.5), int(epochs*0.75)]:
        #     lr = K.get_value(model.optimizer.lr)
        #     K.set_value(model.optimizer.lr, lr*.1)
        #     print("lr decayed to {}".format(lr*.1))

        current_index = 0
        n_step = int(num_images / batch_size)
        pbar = tqdm(range(n_step))
        for stp in pbar:
            b, l = get_batch()
            # generate adversarial examples for adversarial training
            b_adv = attack.perturb(K.get_session(), b, l, batch_size)
            train_loss, train_acc = model.train_on_batch(b_adv, l)
            pbar.set_postfix(acc='%.4f' % train_acc, loss='%.4f' % train_loss)

        ## validation accuracy and loss at each epoch
        val_current_index = 0
        y_pred = []
        y_true = []
        while val_current_index + batch_size < num_images_val:
            b, l = get_val_batch()
            pred = model.predict(b)
            y_pred.extend(pred.tolist())
            y_true.extend(l.tolist())

        y_pred = np.clip(np.array(y_pred), 1e-7, 1.)
        correct_pred = (np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1))
        val_acc = np.mean(correct_pred)
        val_loss = -np.sum(np.mean(y_true * np.log(y_pred),
                                   axis=1)) / val_current_index

        train_loss_log.append(train_loss)
        train_acc_log.append(train_acc)
        val_loss_log.append(val_loss)
        val_acc_log.append(val_acc)
        log = np.stack((np.array(train_loss_log), np.array(train_acc_log),
                        np.array(val_loss_log), np.array(val_acc_log)))

        # save training log
        np.save(
            os.path.join(log_path,
                         'train_log_%s_%s_%.4f.npy' % (cnn, 'adv', eps)), log)

        pbar.set_postfix(acc='%.4f' % train_acc,
                         loss='%.4f' % train_loss,
                         val_acc='%.4f' % val_acc,
                         val_loss='%.4f' % val_loss)
        print(
            "Epoch %s - loss: %.4f - acc: %.4f - val_loss: %.4f - val_acc: %.4f"
            % (ep, train_loss, train_acc, val_loss, val_acc))
        images, labels = shuffle(images, labels)
        if ((ep + existing_ep + 1) % 5 == 0
                or ep == (epochs - existing_ep - 1)):
            model_file = 'models/%s_%s_%s_%.4f_%s.h5' % (dataset, cnn, 'adv',
                                                         eps, ep + existing_ep)
            model.save_weights(model_file)
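
A hypothetical invocation of the trainer above (not in the original source); the argument values simply mirror train()'s defaults and the 'ResNet50' branch it handles.

# Hypothetical script entry point; dataset paths and hyperparameters are
# assumptions and may need adjusting to your environment.
if __name__ == '__main__':
    train(dataset='Fish4Knowledge',
          cnn='ResNet50',
          batch_size=32,
          epochs=50,
          image_size=32,
          eps=0.01)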
Example #3
def get_effnet_model(save_path,
                     model_res=1024,
                     image_size=256,
                     depth=1,
                     size=3,
                     activation='elu',
                     loss='logcosh',
                     optimizer='adam'):

    if os.path.exists(save_path):
        print('Loading model')
        return load_model(save_path)

    # Build model
    print('Building model')
    model_scale = int(2 * (math.log(model_res, 2) - 1))  # e.g. 1024 -> 18
    if (size <= 0):
        effnet = EfficientNetB0(include_top=False,
                                weights='imagenet',
                                input_shape=(image_size, image_size, 3))
    elif (size == 1):
        effnet = EfficientNetB1(include_top=False,
                                weights='imagenet',
                                input_shape=(image_size, image_size, 3))
    elif (size == 2):
        effnet = EfficientNetB2(include_top=False,
                                weights='imagenet',
                                input_shape=(image_size, image_size, 3))
    else:  # size >= 3
        effnet = EfficientNetB3(include_top=False,
                                weights='imagenet',
                                input_shape=(image_size, image_size, 3))

    layer_size = model_scale * 8 * 8 * 8
    if is_square(layer_size):  # work out layer dimensions
        layer_l = int(math.sqrt(layer_size) + 0.5)
        layer_r = layer_l
    else:
        layer_m = math.log(math.sqrt(layer_size), 2)
        layer_l = 2**math.ceil(layer_m)
        layer_r = layer_size // layer_l
    layer_l = int(layer_l)
    layer_r = int(layer_r)

    x_init = None
    inp = Input(shape=(image_size, image_size, 3))
    x = effnet(inp)
    if (size < 1):
        x = Conv2D(model_scale * 8, 1, activation=activation)(x)  # scale down
        if (depth > 0):
            # TreeConnect-inspired layers instead of dense layers; see
            # https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py
            x = Reshape((layer_r, layer_l))(x)
    else:
        if (depth < 1):
            depth = 1
        if (size <= 2):
            x = Conv2D(model_scale * 8 * 4, 1,
                       activation=activation)(x)  # scale down a bit
            # TreeConnect-inspired layers instead of dense layers; see
            # https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py
            x = Reshape((layer_r * 2, layer_l * 2))(x)
        else:
            x = Reshape((384, 256))(x)  # full size for B3
    while (depth > 0):
        x = LocallyConnected1D(layer_r, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        x = LocallyConnected1D(layer_l, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        if x_init is not None:
            x = Add()([x, x_init])  # add skip connection
        x_init = x
        depth -= 1
    if (size >= 2):
        # add unshared layers at the end for different sections of the latent space
        x_init = x
        if layer_r % 3 == 0 and layer_l % 3 == 0:
            a = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            b = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            c = LocallyConnected1D(layer_r, 1, activation=activation)(x)
            a = Permute((2, 1))(a)
            b = Permute((2, 1))(b)
            c = Permute((2, 1))(c)
            a = LocallyConnected1D(layer_l // 3, 1, activation=activation)(a)
            b = LocallyConnected1D(layer_l // 3, 1, activation=activation)(b)
            c = LocallyConnected1D(layer_l // 3, 1, activation=activation)(c)
            x = Concatenate()([a, b, c])
        else:
            a = LocallyConnected1D(layer_r // 2, 1, activation=activation)(x)
            b = LocallyConnected1D(layer_r // 2, 1, activation=activation)(x)
            c = LocallyConnected1D(layer_r // 2, 1, activation=activation)(x)
            d = LocallyConnected1D(layer_r // 2, 1, activation=activation)(x)
            a = Permute((2, 1))(a)
            b = Permute((2, 1))(b)
            c = Permute((2, 1))(c)
            d = Permute((2, 1))(d)
            a = LocallyConnected1D(layer_l // 2, 1, activation=activation)(a)
            b = LocallyConnected1D(layer_l // 2, 1, activation=activation)(b)
            c = LocallyConnected1D(layer_l // 2, 1, activation=activation)(c)
            d = LocallyConnected1D(layer_l // 2, 1, activation=activation)(d)
            x = Concatenate()([a, b, c, d])
        x = Add()([x, x_init])  # add skip connection
    x = Reshape((model_scale, 512))(x)  # train against all dlatent values
    model = Model(inputs=inp, outputs=x)
    # By default: adam optimizer, logcosh used for loss.
    model.compile(loss=loss, metrics=[], optimizer=optimizer)
    return model
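
A hypothetical usage sketch for get_effnet_model() (not part of the original listing); with model_res=1024 and image_size=256 the model maps (256, 256, 3) images to an (18, 512) dlatent tensor, so the assumed training arrays below would have those shapes.

# Hypothetical usage; the checkpoint path and the training arrays
# (images: (N, 256, 256, 3), dlatents: (N, 18, 512)) are assumptions.
model = get_effnet_model('data/finetuned_effnet.h5',
                         model_res=1024,
                         image_size=256,
                         size=3,
                         depth=1)
model.summary()
# model.fit(images, dlatents, batch_size=16, epochs=10)
# model.save('data/finetuned_effnet.h5')  # a later call would then reload it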