def get_fe(fe, input_image):
    """Build a feature-extractor backbone on ``input_image``.

    :param fe: backbone key: 'effnetb0' .. 'effnetb5', 'd53', 'mnetv2',
        'mnet' or 'r50'.
    :param input_image: input tensor handed to the backbone constructor.
    :return: for every key except 'd53', a ``(model, layer_names)`` tuple
        where ``layer_names`` lists three layer names of that model (deepest
        first, judging by the names). For 'd53' the result of
        ``d53(input_image)`` is returned unchanged.
    :raises ValueError: if ``fe`` matches none of the keys.
    """
    # All EfficientNet variants expose the same layer names to tap.
    effnet_taps = ['swish_last', 'block5_i_MB_swish_1', 'block3_i_MB_swish_1']

    def effnet(constructor):
        return constructor(input_tensor=input_image, include_top=False), effnet_taps

    # Lambdas defer name lookup and model construction until a key matches,
    # so only the selected backbone is ever instantiated.
    builders = (
        ('effnetb0', lambda: effnet(EfficientNetB0)),
        ('effnetb1', lambda: effnet(EfficientNetB1)),
        ('effnetb2', lambda: effnet(EfficientNetB2)),
        ('effnetb3', lambda: effnet(EfficientNetB3)),
        ('effnetb4', lambda: effnet(EfficientNetB4)),
        ('effnetb5', lambda: effnet(EfficientNetB5)),
        ('d53', lambda: d53(input_image)),
        ('mnetv2', lambda: (
            MobileNetV2(input_tensor=input_image, weights='imagenet'),
            ['out_relu', 'block_13_expand_relu', 'block_6_expand_relu'])),
        ('mnet', lambda: (
            MobileNet(input_tensor=input_image, weights='imagenet'),
            ['conv_pw_13_relu', 'conv_pw_11_relu', 'conv_pw_5_relu'])),
        ('r50', lambda: (
            ResNet50(input_tensor=input_image, weights='imagenet'),
            ['activation_49', 'activation_40', 'activation_22'])),
    )
    for key, build in builders:
        if fe == key:
            return build()
    raise ValueError('Pls put the correct fe')
def _load_fish_dataset(dataset):
    """Return ``(images, labels, images_val, labels_val)`` for ``dataset``.

    Each loader is asked for an 80/20 train/validation split of the images
    stored under the hard-coded cluster paths below.

    :raises ValueError: if ``dataset`` is not a supported name.
    """
    if dataset == 'Fish4Knowledge':
        image_path = '/data/cephfs/punim0619/Fish4Knowledge/fish_image'
        return get_Fish4Knowledge(image_path, train_ratio=0.8)
    if dataset == 'QUTFish':
        image_path = '/data/cephfs/punim0619/QUT_fish_data'
        return get_QUTFish(image_path, train_ratio=0.8)
    if dataset == 'WildFish':
        image_pathes = [
            '/data/cephfs/punim0619/WildFish/WildFish_part1',
            '/data/cephfs/punim0619/WildFish/WildFish_part2',
            '/data/cephfs/punim0619/WildFish/WildFish_part3',
            '/data/cephfs/punim0619/WildFish/WildFish_part4',
        ]
        return get_WildFish(image_pathes, train_ratio=0.8)
    raise ValueError('Unrecognized dataset: %s' % dataset)


def _build_backbone(cnn, input_shape, num_classes):
    """Return the headless backbone named ``cnn`` (case-insensitive), or None."""
    name = str(cnn).lower()
    if name == 'resnet18':
        return ResNet18(input_shape=input_shape, classes=num_classes, include_top=False)
    if name == 'resnet34':
        return ResNet34(input_shape=input_shape, classes=num_classes, include_top=False)
    if name == 'resnet50':
        return ResNet50(include_top=False, weights='imagenet', input_shape=input_shape)
    if name == 'efficientnetb1':
        constructor = EfficientNetB1
    elif name == 'efficientnetb2':
        constructor = EfficientNetB2
    elif name == 'efficientnetb3':
        constructor = EfficientNetB3
    elif name == 'efficientnetb4':
        constructor = EfficientNetB4
    else:
        return None
    # These EfficientNet builders are handed the Keras submodules explicitly.
    return constructor(input_shape=input_shape, classes=num_classes, include_top=False,
                       backend=keras.backend, layers=keras.layers,
                       models=keras.models, utils=keras.utils)


def _load_image_array(path, image_size, augment):
    """Load ``path`` resized to ``image_size`` x ``image_size``, scaled by 1/255.

    When ``augment`` is true, apply a random shift (20% of width/height), a
    random rotation (up to 10 degrees) and a horizontal flip with p=0.5.
    """
    img = img_to_array(load_img(path, target_size=(image_size, image_size)))
    img /= 255.
    if augment:
        img = random_shift(img, 0.2, 0.2)
        img = random_rotation(img, 10)
        if np.random.random() < 0.5:
            img = flip_axis(img, axis=1)
    return img


def train(dataset='Fish4Knowledge', cnn='resnet50', batch_size=32, epochs=50,
          image_size=32, eps=0.01):
    """Adversarially train one classifier and save its weights periodically.

    Each training batch is augmented (random shift, random rotation, random
    horizontal flip), perturbed with an AdaFGSM attack, and fitted with
    ``train_on_batch`` on the perturbed images only. After every epoch the
    model is evaluated on the validation split. The per-epoch
    train/val loss and accuracy are saved under ``log/<dataset>/``. Weights
    are written to ``models/`` every 5 epochs and after the final epoch.

    :param dataset: 'Fish4Knowledge', 'QUTFish' or 'WildFish'.
    :param cnn: backbone name, matched case-insensitively against 'ResNet18',
        'ResNet34', 'ResNet50' and 'EfficientNetB1' .. 'EfficientNetB4'.
        An unsupported name emits a warning and the function returns None.
    :param batch_size: images per training / validation batch.
    :param epochs: number of training epochs.
    :param image_size: images are resized to ``image_size`` x ``image_size``.
    :param eps: attack budget passed to AdaFGSM as ``epsilon``.
    :raises ValueError: for an unknown ``dataset`` or an empty training split.
    """
    print(
        'Dataset: %s, CNN: %s, loss: adv, epsilon: %.3f batch: %s, epochs: %s' %
        (dataset, cnn, eps, batch_size, epochs))

    IMAGE_SIZE = image_size
    INPUT_SHAPE = (image_size, image_size, 3)

    images, labels, images_val, labels_val = _load_fish_dataset(dataset)

    num_classes = len(np.unique(labels))
    num_images = len(images)
    num_images_val = len(images_val)
    print('Train: classes: %s, images: %s, val images: %s' %
          (num_classes, num_images, num_images_val))
    if num_images == 0:
        raise ValueError('No training images found for dataset %s' % dataset)

    # Read positions into the (per-epoch reshuffled) train / val lists.
    current_index = 0
    val_current_index = 0

    def get_batch():
        """Return ``batch_size`` augmented training images and one-hot labels."""
        nonlocal current_index
        B = np.zeros(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        L = np.zeros(shape=(batch_size))
        filled = 0
        consecutive_failures = 0
        while filled < batch_size:
            # Wrap around so skipped (unreadable) images can never push the
            # read position past the end of the list.
            position = current_index % num_images
            current_index += 1
            try:
                img = _load_image_array(images[position], IMAGE_SIZE, augment=True)
            except Exception:
                traceback.print_exc()
                consecutive_failures += 1
                if consecutive_failures >= num_images:
                    raise RuntimeError('None of the training images could be loaded')
                continue
            consecutive_failures = 0
            B[filled] = img
            L[filled] = labels[position]
            filled += 1
        return B, np_utils.to_categorical(L, num_classes)

    def get_val_batch():
        """Return up to ``batch_size`` validation images and one-hot labels.

        Fewer rows are returned only if the validation list runs out.
        """
        nonlocal val_current_index
        B = np.zeros(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        L = np.zeros(shape=(batch_size))
        filled = 0
        while filled < batch_size and val_current_index < num_images_val:
            position = val_current_index
            val_current_index += 1
            try:
                img = _load_image_array(images_val[position], IMAGE_SIZE, augment=False)
            except Exception:
                traceback.print_exc()
                continue
            B[filled] = img
            L[filled] = labels_val[position]
            filled += 1
        return B[:filled], np_utils.to_categorical(L[:filled], num_classes)

    base_model = _build_backbone(cnn, INPUT_SHAPE, num_classes)
    if base_model is None:
        warnings.warn("Error: unrecognized cnn: %s" % cnn)
        return

    x = Flatten()(base_model.output)
    x = Dense(num_classes, name='dense')(x)
    output = Activation('softmax')(x)
    model = Model(inputs=base_model.input, outputs=output, name=cnn)

    loss = cross_entropy
    base_lr = 1e-2
    sgd = SGD(lr=base_lr, decay=1e-6, momentum=0.9, nesterov=True)
    model.compile(loss=loss, optimizer=sgd, metrics=['accuracy'])

    # AdaFGSM attack for AdvFish training. (LinfPGDAttack with nb_iter=1 and
    # eps_iter=eps is the plain-FGSM alternative.)
    attack = AdaFGSM(model,
                     epsilon=float(eps),
                     random_start=True,
                     loss_func='xent',
                     clip_min=0.,
                     clip_max=1.)

    os.makedirs('models/', exist_ok=True)
    log_path = 'log/%s' % dataset
    os.makedirs(log_path, exist_ok=True)

    # Epoch offset for resumed runs; loading earlier weights is not
    # implemented here, so training always starts from scratch.
    existing_ep = 0
    train_loss_log, train_acc_log, val_loss_log, val_acc_log = [], [], [], []

    eta_min = 1e-5
    eta_max = base_lr
    # At least one step per epoch so train_loss/train_acc are always defined.
    n_step = max(1, num_images // batch_size)

    for ep in range(epochs - existing_ep):
        # Cosine learning-rate annealing over the full run.
        global_ep = ep + existing_ep
        lr = eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * global_ep / epochs)) / 2
        K.set_value(model.optimizer.lr, lr)

        current_index = 0
        pbar = tqdm(range(n_step))
        for _ in pbar:
            b, l = get_batch()
            # Adversarial training: fit on the perturbed batch only.
            b_adv = attack.perturb(K.get_session(), b, l, batch_size)
            train_loss, train_acc = model.train_on_batch(b_adv, l)
            pbar.set_postfix(acc='%.4f' % train_acc, loss='%.4f' % train_loss)

        # Validation accuracy and loss at the end of each epoch.
        val_current_index = 0
        y_pred = []
        y_true = []
        while val_current_index + batch_size < num_images_val:
            b, l = get_val_batch()
            if len(b) == 0:
                continue
            y_pred.extend(model.predict(b).tolist())
            y_true.extend(l.tolist())

        if y_true:
            y_true_arr = np.array(y_true)
            y_pred_arr = np.clip(np.array(y_pred), 1e-7, 1.)
            correct_pred = np.argmax(y_pred_arr, axis=1) == np.argmax(y_true_arr, axis=1)
            val_acc = np.mean(correct_pred)
            # Categorical cross-entropy: sum over classes, mean over samples.
            val_loss = -np.mean(np.sum(y_true_arr * np.log(y_pred_arr), axis=1))
        else:
            val_acc = val_loss = float('nan')

        train_loss_log.append(train_loss)
        train_acc_log.append(train_acc)
        val_loss_log.append(val_loss)
        val_acc_log.append(val_acc)

        log = np.stack((np.array(train_loss_log), np.array(train_acc_log),
                        np.array(val_loss_log), np.array(val_acc_log)))
        np.save(os.path.join(log_path, 'train_log_%s_%s_%.4f.npy' % (cnn, 'adv', eps)), log)

        pbar.set_postfix(acc='%.4f' % train_acc, loss='%.4f' % train_loss,
                         val_acc='%.4f' % val_acc, val_loss='%.4f' % val_loss)
        print("Epoch %s - loss: %.4f - acc: %.4f - val_loss: %.4f - val_acc: %.4f" %
              (ep, train_loss, train_acc, val_loss, val_acc))

        images, labels = shuffle(images, labels)
        if ((global_ep + 1) % 5 == 0) or (ep == (epochs - existing_ep - 1)):
            model_file = 'models/%s_%s_%s_%.4f_%s.h5' % (dataset, cnn, 'adv', eps, global_ep)
            model.save_weights(model_file)
def get_effnet_model(save_path, model_res=1024, image_size=256, depth=1, size=3, activation='elu', loss='logcosh', optimizer='adam'):
    """Load the model saved at ``save_path``, or build and compile a new one.

    A new model puts an ImageNet-pretrained EfficientNet (B0 for size <= 0,
    B1/B2 for size 1/2, B3 for size >= 3) in front of stacks of
    LocallyConnected1D + Permute layers. Its output is reshaped to
    ``(model_scale, 512)``, where ``model_scale = int(2 * (log2(model_res) - 1))``.

    :param save_path: path checked first; if it exists, ``load_model`` is returned.
    :param model_res: resolution used to derive ``model_scale`` (1024 -> 18).
    :param image_size: side length of the square RGB input.
    :param depth: number of LocallyConnected1D/Permute rounds (raised to at
        least 1 when ``size >= 1``).
    :param size: backbone selector; ``size >= 2`` also adds unshared heads.
    :param activation: activation for every Conv2D / LocallyConnected1D layer.
    :param loss: loss passed to ``compile``.
    :param optimizer: optimizer passed to ``compile``.
    :return: a compiled Keras ``Model``.
    """
    if os.path.exists(save_path):
        print('Loading model')
        return load_model(save_path)

    print('Building model')
    model_scale = int(2 * (math.log(model_res, 2) - 1))  # e.g. 1024 -> 18

    image_shape = (image_size, image_size, 3)
    if size <= 0:
        effnet = EfficientNetB0(include_top=False, weights='imagenet', input_shape=image_shape)
    elif size == 1:
        effnet = EfficientNetB1(include_top=False, weights='imagenet', input_shape=image_shape)
    elif size == 2:
        effnet = EfficientNetB2(include_top=False, weights='imagenet', input_shape=image_shape)
    elif size >= 3:
        effnet = EfficientNetB3(include_top=False, weights='imagenet', input_shape=image_shape)

    # Split model_scale * 512 values into a (layer_r, layer_l) grid: an exact
    # square when possible, otherwise a power-of-two width.
    layer_size = model_scale * 8 * 8 * 8
    if is_square(layer_size):
        side = int(math.sqrt(layer_size) + 0.5)
        layer_l, layer_r = side, side
    else:
        layer_l = 2 ** math.ceil(math.log(math.sqrt(layer_size), 2))
        layer_r = layer_size // layer_l
    layer_l, layer_r = int(layer_l), int(layer_r)

    x_init = None
    inp = Input(shape=image_shape)
    x = effnet(inp)

    if size < 1:
        x = Conv2D(model_scale * 8, 1, activation=activation)(x)  # scale down
        if depth > 0:
            # TreeConnect-inspired layers instead of dense layers, see
            # https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py
            x = Reshape((layer_r, layer_l))(x)
    else:
        depth = max(depth, 1)
        if size <= 2:
            x = Conv2D(model_scale * 8 * 4, 1, activation=activation)(x)  # scale down a bit
            # TreeConnect-inspired layers instead of dense layers, see
            # https://github.com/OliverRichter/TreeConnect/blob/master/cifar.py
            x = Reshape((layer_r * 2, layer_l * 2))(x)
        else:
            # Full size for B3. NOTE(review): hard-codes 384 * 256 values,
            # presumably the B3 feature map at image_size=256 -- confirm
            # before using other image sizes.
            x = Reshape((384, 256))(x)

    # Each round mixes along one axis, transposes, mixes along the other and
    # transposes back; rounds after the first add a skip connection.
    while depth > 0:
        x = LocallyConnected1D(layer_r, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        x = LocallyConnected1D(layer_l, 1, activation=activation)(x)
        x = Permute((2, 1))(x)
        if x_init is not None:
            x = Add()([x, x_init])
        x_init = x
        depth -= 1

    if size >= 2:
        # Unshared heads for different sections of the latent space, joined
        # back together with a skip connection.
        # NOTE(review): by the shapes traced here the Add only lines up when
        # layer_l == layer_r (e.g. model_res=1024); verify for other values.
        x_init = x
        if layer_r % 3 == 0 and layer_l % 3 == 0:
            n_heads, head_r, head_l = 3, layer_r, layer_l // 3
        else:
            n_heads, head_r, head_l = 4, layer_r // 2, layer_l // 2
        # Layers are created stage by stage (all first projections, then all
        # permutes, then all second projections).
        heads = [LocallyConnected1D(head_r, 1, activation=activation)(x) for _ in range(n_heads)]
        heads = [Permute((2, 1))(head) for head in heads]
        heads = [LocallyConnected1D(head_l, 1, activation=activation)(head) for head in heads]
        x = Concatenate()(heads)
        x = Add()([x, x_init])

    x = Reshape((model_scale, 512))(x)  # train against all dlatent values

    model = Model(inputs=inp, outputs=x)
    # By default: adam optimizer, logcosh loss.
    model.compile(loss=loss, metrics=[], optimizer=optimizer)
    return model