def _main(args):
    weights_path = args.weights
    train_data, valid_data = rbbox.load_data(args.dataset_path)
    print "Train Size: ", len(train_data)
    print "Valid Size: ", len(valid_data)
    target_size = (224, 224)
    batch_size = 32
    train_generator = rbbox.BatchGenerator(train_data,
                                           batch_size=batch_size,
                                           target_size=target_size,
                                           use_bbox=True)
    valid_generator = rbbox.BatchGenerator(valid_data,
                                           batch_size=batch_size,
                                           target_size=target_size,
                                           use_bbox=True)
    model_rbbox_regr = rbbox.get_model_rbbox_regressor(target_size)

    if weights_path is not None and os.path.isfile(weights_path):
        model_rbbox_regr.load_weights(weights_path)

    model_rbbox_regr.summary()

    optimizer = Adam(lr=0.0001, epsilon=1e-08, decay=0.0005)
    model_rbbox_regr.compile(loss='mean_squared_error',
                             optimizer=optimizer,
                             metrics=['accuracy'])

    scores = model_rbbox_regr.evaluate_generator(generator=valid_generator,
                                                 steps=len(valid_generator),
                                                 use_multiprocessing=True)

    for i in range(len(model_rbbox_regr.metrics_names)):
        print("%s: %.2f" % (model_rbbox_regr.metrics_names[i], scores[i]))

    checkpoint = ModelCheckpoint(weights_path,
                                 monitor='val_loss',
                                 verbose=1,
                                 save_best_only=True,
                                 mode='min',
                                 period=1)

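    # Seed the callback's internal best value with the evaluated validation loss so that
    # resumed training only overwrites the saved weights when it actually improves on them.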
    checkpoint.best = scores[0]

    tensorboard = TensorBoard(log_dir=os.path.expanduser('~/logs/'),
                              histogram_freq=0,
                              write_graph=True,
                              write_images=False)

    model_rbbox_regr.fit_generator(generator=train_generator,
                                   steps_per_epoch=len(train_generator) * 8,
                                   epochs=32,
                                   verbose=1,
                                   validation_data=valid_generator,
                                   validation_steps=len(valid_generator),
                                   callbacks=[checkpoint, tensorboard])
# checkpoint

filepath = os.path.join(model_dir,
                        checkpoint_name)  # name for the best model weights
# verbose=1 prints a message whenever a better checkpoint is saved during training
checkpoint = ModelCheckpoint(filepath,
                             monitor='val_acc',
                             verbose=1,
                             save_best_only=True,
                             mode='max')

#print("checkpoint best =", best_weight_file_name[-11:-5])
#exit() # To test only

# get the saved model's best accuracy value from its file name
checkpoint.best = float(best_weight_file_name[-11:-5])  # manually set the best checkpoint accuracy
callbacks_list = [checkpoint]

#############################################################################
# Evaluate available model to update callback list

#############################################################################

history = model.fit_generator(train_generator,
                              train_generator.n // batch_size,
                              epochs=number_of_epochs,
                              workers=4,
                              validation_data=validation_generator,
                              validation_steps=validation_generator.n //
                              batch_size,
                              callbacks=callbacks_list)
Example #3
reduce_lr = ReduceLROnPlateau(factor=0.7, patience=4, verbose=1, min_lr=1e-6)
callbacks_list = [checkpoint, reduce_lr]

#model.summary()

trainable_layer = [199] if bert_version == 24 else [103, 95]  #, 87, 71, 55]
epoch_num = [80, 40, 40, 8, 4]
batch_size = [8, 8, 8, 8, 4]
resume = True
st = -1
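# each stage l unfreezes the layers above trainable_layer[l], recompiles the model,
# and fine-tunes it for epoch_num[l] epochs with batch_size[l]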
if resume:
    if os.path.exists(opt_filepath+'.rc'):
        print('\033[32;1mLoad Model\033[0m')
        with open(opt_filepath+'.rc', 'r') as f:
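            # .rc layout: first line is the last finished stage index, second line is the best monitored metric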
            st = int(f.readline())
            checkpoint.best = float(f.readline())
    for l in range(st + 1, len(trainable_layer)):
        for i, layer in enumerate(model.layers):
            if i > trainable_layer[l]:
                layer.trainable = True
                print(layer.name, layer.trainable)
            else: 
                layer.trainable = False

        model.compile(loss=f1_loss, optimizer=Adam(1e-4), metrics=[f1_acc, 'acc'])
        if os.path.exists(opt_filepath):
            model.load_weights(opt_filepath)

        model.fit([X_train, seg_train, citation_count_train], Y_train[:, :-1], batch_size=batch_size[l], epochs=epoch_num[l], callbacks=callbacks_list, validation_data=([X_val, seg_val, citation_count_val], Y_val[:, :-1]))

        with open(opt_filepath+'.rc', 'w') as f:
            # assumed continuation (truncated in the source): save the finished stage
            # index and the best metric in the same two-line layout that is read above
            f.write('%d\n%f\n' % (l, checkpoint.best))
Example #4
    if not os.path.exists(PathOutput):
        os.makedirs(PathOutput)
    else:
        for dirName, subdirList, fileList in os.walk(PathOutput):
            for filename in fileList:
                os.remove(PathOutput+filename)
    
    logfile=PathOutput+'allnode_PIN.log'
    csv_logger = CSVLogger(logfile)
    filename="weights.{epoch:03d}-{val_loss:.2f}.hdf5"
    checkpointer = ModelCheckpoint(monitor='val_loss', filepath=PathOutput+filename, verbose=1, save_best_only=True, save_weights_only=True)
    
    #model.set_weights(init_weights)
     
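    # re-initialize the checkpoint's internal state so the first epoch of this run
    # is always treated as an improvement (best starts at +inf for val_loss)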
    checkpointer.epochs_since_last_save = 0
    checkpointer.best = np.Inf
    lrate = LearningRateScheduler(my_learning_rate, verbose=1)
    model.fit(Data_0, Labels_0, epochs=epochs, batch_size=batch_size, validation_data=(Data_1, Labels_1),
              callbacks=[checkpointer, csv_logger, lrate], verbose=2, class_weight=weight)
    
    ##----------------------- Look for the best model to evaluate
    
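    # with save_best_only=True every saved file beats the previous one, so after sorting
    # the epoch-numbered filenames the last entry is the best checkpoint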
    for dirName, subdirList, fileList in os.walk(PathOutput):
        fileList.sort()
    tmp = fileList[-1]
    print(tmp)
    filename = PathOutput + tmp
    model.load_weights(filename)
    model.save(PathOutput + 'best_model.hd5')

else:
Example #5
    model_checkpoint = ModelCheckpoint(args.model, save_best_only=True, verbose=2)
    callbacks = [early_stopping, model_checkpoint]
    if args.history is not None:
        csv_logger = CSVLogger(args.history, append=True)
        callbacks.append(csv_logger)
    
    # Load model
    custom_objects = dict(inspect.getmembers(losses, inspect.isfunction))
    model = models.load_model(args.model, custom_objects=custom_objects)
    model.summary()

    # Get score
    if args.cont:
        val_scores = model.evaluate_generator(val_generator, val_batches)
        val_loss_idx = model.metrics_names.index('loss')
        print('Loaded model with: %s' % ' - '.join(
                'val_%s: %0.4f' % (metric_name, score)
                for metric_name, score in zip(model.metrics_names, val_scores)))
        
        # Update callbacks
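        # seed both callbacks with the loaded model's validation loss so that only epochs
        # that actually beat it overwrite the checkpoint or reset the early-stopping patience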
        model_checkpoint.best = val_scores[val_loss_idx]
        early_stopping.best = val_scores[val_loss_idx]
        
    model.fit_generator(train_generator,
                        train_batches,
                        epochs=args.epochs,
                        callbacks=callbacks,
                        validation_data=val_generator,
                        validation_steps=val_batches)

Example #6
def apricot_plus(model, model_weights_dir, dataset, adjustment_strategy):
    x_train, x_test, y_train, y_test = load_dataset(dataset)
    x_train_val, x_val, y_train_val, y_val = split_validation_dataset(x_train, y_train)

    train_size = len(x_train)
    val_size = len(x_train_val)
    test_size = len(x_test)

    fixed_model = model
    submodel_dir = os.path.join(model_weights_dir, 'submodels')
    trained_weights_path = os.path.join(model_weights_dir, 'trained.h5')
    fixed_weights_path = os.path.join(model_weights_dir, 'apricot_plus_fixed_{}.h5'.format(adjustment_strategy))
    log_path = os.path.join(model_weights_dir, 'apricot_plus_{}.log'.format(adjustment_strategy))

    if not os.path.exists(fixed_weights_path):
        fixed_model.save_weights(fixed_weights_path)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant', cval=0.)
    datagen.fit(x_train)

    fixed_model.load_weights(trained_weights_path)

    start = datetime.now()
    logger('---------------original model---------------', log_path)
    _, base_train_acc = fixed_model.evaluate(x_train_val, y_train_val)
    _, base_val_acc = fixed_model.evaluate(x_val, y_val)
    _, base_test_acc = fixed_model.evaluate(x_test, y_test)

    logger('train acc: {:.4f}, val acc: {:.4f}, test acc: {:.4f}'.format(base_train_acc, base_val_acc, base_test_acc),
           log_path)
    # to simplify the process, get the classification results of the submodels first.
    # do not shuffle the training dataset.
    fail_xs, fail_ys, fail_ys_label, fail_num, fail_index = get_indexed_failing_cases(fixed_model, x_train, y_train)

    print('getting sub correct matrix...')
    sub_correct_matrix_path = os.path.join(model_weights_dir, 'corr_mat_{}.npy'.format(NUM_SUBMODELS))
    if not os.path.exists(sub_correct_matrix_path):
        print('generating matrix....')
        sub_correct_mat = apricot_cal_sub_corr_mat(fixed_model, submodel_dir, fail_xs, fail_ys, num_submodels=NUM_SUBMODELS)
        np.save(sub_correct_matrix_path, sub_correct_mat)
    else:
        print('loading matrix...')
        sub_correct_mat = np.load(sub_correct_matrix_path)

    fixed_model.load_weights(trained_weights_path)

    weights_list = get_weights_list(fixed_model, submodel_dir, NUM_SUBMODELS)

    best_train_acc = base_train_acc
    best_val_acc = base_val_acc
    best_test_acc = base_test_acc

    # Apricot Plus: iterates failing cases.
    print('start the main iteration process...')
    for count in range(LOOP_COUNT):  # iterate 3 times.
        np.random.shuffle(sub_correct_mat)
        # calculate the iteration number.
        iter_count, res = divmod(sub_correct_mat.shape[0], FIX_BATCH_SIZE)
        if res != 0:
            iter_count += 1
        for i in range(iter_count):
            curr_w = fixed_model.get_weights()
            batch_corr_mat = sub_correct_mat[FIX_BATCH_SIZE*i: FIX_BATCH_SIZE*(i+1)]  # 20 samples in one batch
            adjust_w = batch_get_adjust_w(curr_w, batch_corr_mat, weights_list, adjustment_strategy)

            fixed_model.set_weights(adjust_w)
            x = int(count * sub_correct_mat.shape[0] + i + 1)
            y = int(LOOP_COUNT * sub_correct_mat.shape[0])

            _, curr_acc = fixed_model.evaluate(x_val, y_val)
            print('[iteration {}/{}] current val acc: {:.4f}'.format(x, y, curr_acc))
            if curr_acc > best_val_acc:
                best_val_acc = curr_acc
                fixed_model.save_weights(fixed_weights_path)

                checkpoint = ModelCheckpoint(fixed_weights_path, monitor='val_accuracy', verbose=1, save_best_only=True,
                                             mode='max')
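                # a fresh ModelCheckpoint in 'max' mode starts with best = -inf, so seed it
                # with the current best val acc so the short fine-tune only saves improvements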
                checkpoint.best = best_val_acc
                hist = fixed_model.fit_generator(datagen.flow(x_train_val, y_train_val, batch_size=BATCH_SIZE),
                                                 steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                                                 validation_data=(x_val, y_val),
                                                 epochs=3,  # 3 epochs
                                                 callbacks=[checkpoint])
                fixed_model.load_weights(fixed_weights_path)

                temp_val_acc = np.max(np.array(hist.history['val_accuracy']))
                temp_train_acc = np.max(np.array(hist.history['accuracy']))
                if temp_val_acc > best_val_acc:
                    # val acc improved.
                    best_val_acc = temp_val_acc
                    best_train_acc = temp_train_acc
                _, best_test_acc = fixed_model.evaluate(x_test, y_test)
                # print(best_train_acc, best_val_acc)
                logger('Improved. Train acc: {:.4f}, val acc: {:.4f}, test acc: {:.4f}'.format(best_train_acc,
                                                                                               best_val_acc,
                                                                                               best_test_acc), log_path)
            else:
                fixed_model.load_weights(fixed_weights_path)
    end = datetime.now()
    logger('Spend time: {}'.format(end - start), log_path)
Example #7
def apricot5(model,
             model_weights_dir,
             dataset,
             adjustment_strategy,
             activation='binary'):
    """
    including Apricot and Apricot lite
    input:
        * dataset: [x_train_val, y_train_val, x_val, y_val, x_test, y_test]
    """
    # package the dataset
    x_train, x_test, y_train, y_test = load_dataset(dataset)
    x_train_val, x_val, y_train_val, y_val = split_validation_dataset(
        x_train, y_train)

    # x_train_val = np.concatenate((x_train_val, x_val), axis=0)
    # y_train_val = np.concatenate((y_train_val, y_val), axis=0)
    # print(x_train_val.shape, type(x_train_val))
    # print(y_train_val.shape, type(y_train_val))
    # return

    fixed_model = model

    submodel_dir = os.path.join(model_weights_dir, 'submodels')
    trained_weights_path = os.path.join(model_weights_dir, 'trained.h5')
    fixed_weights_path = os.path.join(
        model_weights_dir,
        'compare_fixed_{}_{}.h5'.format(adjustment_strategy, activation))
    log_path = os.path.join(model_weights_dir,
                            'compare_log_{}.txt'.format(adjustment_strategy))

    if not os.path.exists(fixed_weights_path):
        fixed_model.save_weights(fixed_weights_path)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant',
                                 cval=0.)

    datagen.fit(x_train)

    logger('----------original model----------', log_path)

    # submodels
    _, base_train_acc = fixed_model.evaluate(x_train_val, y_train_val)
    logger('The train accuracy: {:.4f}'.format(base_train_acc), log_path)
    _, base_val_acc = fixed_model.evaluate(x_val, y_val)
    # print('The validation accuracy: {:.4f}'.format(base_val_acc))
    logger('The validation accuracy: {:.4f}'.format(base_val_acc), log_path)
    _, base_test_acc = fixed_model.evaluate(x_test, y_test)
    # print('The test accuracy: {:.4f}'.format(base_test_acc))
    logger('The test accuracy: {:.4f}'.format(base_test_acc), log_path)

    best_weights = fixed_model.get_weights()
    best_acc = base_val_acc

    # find all indices of xs that original model fails on them.
    # fail_xs, fail_ys, fail_ys_label, fail_num = get_failing_cases(fixed_model, x_train_val, y_train_val)
    fail_xs, fail_ys, fail_ys_label, fail_num = get_failing_cases(
        fixed_model, x_train, y_train)  # use the whole training dataset

    sub_correct_matrix_path = os.path.join(
        model_weights_dir,
        'corr_matrix_{}_{}.npy'.format(settings.RANDOM_SEED,
                                       settings.NUM_SUBMODELS))
    sub_correct_matrix = None  # 1: predicts correctly, -1: predicts incorrectly.
    print('obtaining sub correct matrix...')

    if not os.path.exists(sub_correct_matrix_path):
        # obtain submodel correctness matrix
        sub_correct_matrix = cal_sub_corr_matrix(fixed_model,
                                                 sub_correct_matrix_path,
                                                 submodel_dir,
                                                 fail_xs,
                                                 fail_ys,
                                                 fail_ys_label,
                                                 fail_num,
                                                 num_submodels=20)
    else:
        sub_correct_matrix = np.load(sub_correct_matrix_path)

    # generate random matrix for comparison.
    # sub_correct_matrix = np.random.randint(0,2, sub_correct_matrix.shape)
    # sub_correct_matrix[sub_correct_matrix == 0] = -1
    sub_correct_matrix = np.ones(sub_correct_matrix.shape)
    sub_correct_matrix = sub_correct_matrix * -1

    sub_weights_list = get_submodels_weights(fixed_model, submodel_dir)
    print('collected.')
    fixed_model.load_weights(trained_weights_path)
    # print(sub_correct_matrix.shape)
    # print(sub_correct_matrix[0:20, :])

    # print('start fixing process...')
    logger('----------start fixing process----------', log_path)
    logger(
        'number of cases to be adjusted: {}'.format(
            sub_correct_matrix.shape[0]), log_path)
    for _ in range(settings.LOOP_COUNT):
        np.random.shuffle(sub_correct_matrix)

        # load batches rather than single input.
        iter_num, rest = divmod(sub_correct_matrix.shape[0],
                                settings.FIX_BATCH_SIZE)
        if rest != 0:
            iter_num += 1

        print('iter num: {}'.format(iter_num))
        # batch version
        for i in range(iter_num):
            curr_weights = fixed_model.get_weights()
            batch_corr_matrix = sub_correct_matrix[settings.FIX_BATCH_SIZE *
                                                   i:settings.FIX_BATCH_SIZE *
                                                   (i + 1), :]
            # print('---------------------------------')
            # print(batch_corr_matrix)
            # print('---------------------------------')
            corr_w, incorr_w = batch_get_adjustment_weights(
                batch_corr_matrix, sub_weights_list, adjustment_strategy,
                curr_weights)
            # print(len(corr_w),len(incorr_w))
            print('calculating batch adjust weights...')
            # adjust_w = None
            # print(adjust_w)
            adjust_w = batch_adjust_weights_func(curr_weights,
                                                 corr_w,
                                                 incorr_w,
                                                 adjustment_strategy,
                                                 activation=activation)
            # print(curr_weights[0][0])
            # print('----------')
            # print(adjust_w[0][0])
            fixed_model.set_weights(adjust_w)

            _, curr_acc = fixed_model.evaluate(x_val, y_val)
            print('After adjustment, the validation accuracy: {:.4f}'.format(
                curr_acc))

            if curr_acc > best_acc:
                best_acc = curr_acc
                fixed_model.save_weights(fixed_weights_path)

                if adjustment_strategy <= 3:
                    # further training epochs.
                    checkpoint = ModelCheckpoint(fixed_weights_path,
                                                 monitor='val_accuracy',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
                    checkpoint.best = best_acc
                    hist = fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=settings.BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=settings.FURTHER_ADJUSTMENT_EPOCHS,
                        callbacks=[checkpoint])

                    # for key in hist.history:
                    #     print(key)

                    fixed_model.load_weights(fixed_weights_path)

                    # eval the model
                    _, val_acc = fixed_model.evaluate(x_val, y_val, verbose=0)
                    # _, test_acc = fixed_model.evaluate(x_test, y_test, verbose=0)
                    best_acc = val_acc

                    # print('validation accuracy after further training: {:.4f}'.format(test_acc))
                    logger(
                        'validation accuracy improved, after further training: {:.4f}'
                        .format(val_acc), log_path)
                else:
                    logger(
                        'validation accuracy improved: {:.4f}'.format(
                            best_acc), log_path)
            else:
                fixed_model.load_weights(fixed_weights_path)
                # pass

    fixed_model.load_weights(fixed_weights_path)
    if adjustment_strategy > 3:
        # final training process.
        _, val_acc = fixed_model.evaluate(x_val, y_val)
        checkpoint = ModelCheckpoint(fixed_weights_path,
                                     monitor='val_accuracy',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='max')
        checkpoint.best = val_acc

        fixed_model.fit_generator(
            datagen.flow(x_train_val,
                         y_train_val,
                         batch_size=settings.BATCH_SIZE),
            steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
            validation_data=(x_val, y_val),
            epochs=20,
            callbacks=[checkpoint])
        fixed_model.load_weights(fixed_weights_path)

    # final evaluation.
    _, test_acc = fixed_model.evaluate(x_test, y_test, verbose=0)
    logger('----------final evaluation----------', log_path)
    logger('test accuracy: {:.4f}'.format(test_acc), log_path)
Example #8
def apricot2(model,
             model_weights_dir,
             dataset,
             adjustment_strategy,
             activation='binary'):
    """
    including Apricot and Apricot lite
    input:
        * dataset: [x_train_val, y_train_val, x_val, y_val, x_test, y_test]
    """
    # package the dataset
    x_train, x_test, y_train, y_test = load_dataset(dataset)
    x_train_val, x_val, y_train_val, y_val = split_validation_dataset(
        x_train, y_train)

    fixed_model = model

    submodel_dir = os.path.join(model_weights_dir, 'submodels')
    trained_weights_path = os.path.join(model_weights_dir, 'trained.h5')
    fixed_weights_path = os.path.join(
        model_weights_dir, 'fixed_{}_{}.h5'.format(adjustment_strategy,
                                                   activation))
    log_path = os.path.join(model_weights_dir,
                            'log_{}.txt'.format(adjustment_strategy))

    if not os.path.exists(fixed_weights_path):
        fixed_model.save_weights(fixed_weights_path)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant',
                                 cval=0.)

    datagen.fit(x_train_val)

    logger('----------original model----------', log_path)

    # submodels
    _, base_val_acc = fixed_model.evaluate(x_val, y_val)
    # print('The validation accuracy: {:.4f}'.format(base_val_acc))
    logger('The validation accuracy: {:.4f}'.format(base_val_acc), log_path)
    _, base_test_acc = fixed_model.evaluate(x_test, y_test)
    # print('The test accuracy: {:.4f}'.format(base_test_acc))
    logger('The test accuracy: {:.4f}'.format(base_test_acc), log_path)

    best_weights = fixed_model.get_weights()
    best_acc = base_val_acc

    # find all indices of xs that original model fails on them.
    fail_xs, fail_ys, fail_ys_label, fail_num = get_failing_cases(
        fixed_model, x_train_val, y_train_val)

    sub_correct_matrix_path = os.path.join(
        model_weights_dir,
        'corr_matrix_{}_{}.npy'.format(settings.RANDOM_SEED,
                                       settings.NUM_SUBMODELS))
    sub_correct_matrix = None  # 1: predicts correctly, 0: predicts incorrectly.
    print('obtaining sub correct matrix...')

    if not os.path.exists(sub_correct_matrix_path):
        # obtain submodel correctness matrix
        sub_correct_matrix = cal_sub_corr_matrix(fixed_model,
                                                 sub_correct_matrix_path,
                                                 submodel_dir, fail_xs,
                                                 fail_ys, fail_ys_label,
                                                 fail_num)
    else:
        sub_correct_matrix = np.load(sub_correct_matrix_path)

    sub_weights_list = get_submodels_weights(fixed_model, submodel_dir)
    print('collected.')
    fixed_model.load_weights(trained_weights_path)

    # print('start fixing process...')
    logger('----------start fixing process----------', log_path)
    for _ in range(settings.LOOP_COUNT):
        np.random.shuffle(sub_correct_matrix)

        for index in range(sub_correct_matrix.shape[0]):
            curr_weights = fixed_model.get_weights()
            corr_mat = sub_correct_matrix[index, :]

            print('obtaining correct and incorrect weights...')
            if adjustment_strategy <= 3:
                corr_w, incorr_w = get_adjustment_weights(
                    corr_mat, sub_weights_list, adjustment_strategy)
                print('calculating adjust weights...')
                adjust_w = adjust_weights_func(curr_weights,
                                               corr_w,
                                               incorr_w,
                                               adjustment_strategy,
                                               activation=activation)
            else:  # lite version
                print('calculating adjust weights...')
                adjust_w = adjust_weights_func_lite(corr_mat, sub_weights_list,
                                                    curr_weights)

            if adjust_w == -1:
                continue
            fixed_model.set_weights(adjust_w)

            _, curr_acc = fixed_model.evaluate(x_val, y_val)
            print('After adjustment, the validation accuracy: {:.4f}'.format(
                curr_acc))

            if curr_acc > best_acc:
                best_acc = curr_acc
                fixed_model.save_weights(fixed_weights_path)

                if adjustment_strategy <= 3:
                    # further training epochs.
                    checkpoint = ModelCheckpoint(fixed_weights_path,
                                                 monitor='val_acc',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
                    checkpoint.best = best_acc
                    fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=settings.BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=settings.FURTHER_ADJUSTMENT_EPOCHS,
                        callbacks=[checkpoint])

                    fixed_model.load_weights(fixed_weights_path)

                    # eval the model
                    _, val_acc = fixed_model.evaluate(x_val, y_val, verbose=0)
                    # _, test_acc = fixed_model.evaluate(x_test, y_test, verbose=0)
                    best_acc = val_acc

                    # print('validation accuracy after further training: {:.4f}'.format(test_acc))
                    logger(
                        'validation accuracy improved, after further training: {:.4f}'
                        .format(val_acc), log_path)
                else:
                    logger(
                        'validation accuracy improved: {:.4f}'.format(
                            best_acc), log_path)
            else:
                fixed_model.load_weights(fixed_weights_path)

    fixed_model.load_weights(fixed_weights_path)
    if adjustment_strategy > 3:
        # final training process.
        _, val_acc = fixed_model.evaluate(x_val, y_val)
        checkpoint = ModelCheckpoint(fixed_weights_path,
                                     monitor='val_acc',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='max')
        checkpoint.best = val_acc

        fixed_model.fit_generator(
            datagen.flow(x_train_val,
                         y_train_val,
                         batch_size=settings.BATCH_SIZE),
            steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
            validation_data=(x_val, y_val),
            epochs=settings.FURTHER_ADJUSTMENT_EPOCHS,
            callbacks=[checkpoint])
        fixed_model.load_weights(fixed_weights_path)

    # final evaluation.
    _, test_acc = fixed_model.evaluate(x_test, y_test, verbose=0)
    logger('----------final evaluation----------', log_path)
    logger('test accuracy: {:.4f}'.format(test_acc), log_path)
Example #9
def apricot(model, model_weights_dir, dataset, adjustment_strategy):
    x_train, x_test, y_train, y_test = load_dataset(dataset)
    x_train_val, x_val, y_train_val, y_val = split_validation_dataset(
        x_train, y_train)

    train_size = len(x_train)
    val_size = len(x_train_val)
    test_size = len(x_test)

    fixed_model = model
    submodel_dir = os.path.join(model_weights_dir, 'submodels')
    trained_weights_path = os.path.join(model_weights_dir, 'trained.h5')
    fixed_weights_path = os.path.join(
        model_weights_dir, 'apricot_fixed_{}.h5'.format(adjustment_strategy))
    log_path = os.path.join(model_weights_dir,
                            'apricot_{}.log'.format(adjustment_strategy))

    if not os.path.exists(fixed_weights_path):
        fixed_model.save_weights(fixed_weights_path)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant',
                                 cval=0.)
    datagen.fit(x_train)

    fixed_model.load_weights(trained_weights_path)

    start = datetime.now()
    logger('---------------original model---------------', log_path)
    _, base_train_acc = fixed_model.evaluate(x_train_val, y_train_val)
    _, base_val_acc = fixed_model.evaluate(x_val, y_val)
    _, base_test_acc = fixed_model.evaluate(x_test, y_test)

    # base_train_acc, base_val_acc, base_test_acc = (1,1,1)  # test

    logger(
        'train acc: {:.4f}, val acc: {:.4f}, test acc: {:.4f}'.format(
            base_train_acc, base_val_acc, base_test_acc), log_path)

    # to simplify the process, get the classification results of the submodels first.
    # do not shuffle the training dataset.
    fail_xs, fail_ys, fail_ys_label, fail_num, fail_index = get_indexed_failing_cases(
        fixed_model, x_train, y_train)

    print('getting sub correct matrix...')
    sub_correct_matrix_path = os.path.join(
        model_weights_dir, 'corr_mat_{}.npy'.format(NUM_SUBMODELS))
    if not os.path.exists(sub_correct_matrix_path):
        print('generating matrix....')
        sub_correct_mat = apricot_cal_sub_corr_mat(fixed_model,
                                                   submodel_dir,
                                                   fail_xs,
                                                   fail_ys,
                                                   num_submodels=NUM_SUBMODELS)
        np.save(sub_correct_matrix_path, sub_correct_mat)
    else:
        print('loading matrix...')
        sub_correct_mat = np.load(sub_correct_matrix_path)

    # iterates all training dataset.
    iter_batch_size = 20  # TODO revise hard-coding
    iter_num, ret = divmod(train_size, iter_batch_size)
    fail_idx_seq = get_formatted_batch_sequence(
        fail_index, total_num=train_size)  # binary indicator

    if ret != 0:
        iter_num += 1

    # the main process
    train_total_index = np.array([i for i in range(train_size)])  # indices for all training samples.

    sub_weights_list = get_weights_list(fixed_model,
                                        submodel_dir,
                                        num_submodels=NUM_SUBMODELS)

    fixed_model.load_weights(trained_weights_path)  # load the trained model.
    best_weights = fixed_model.get_weights()  # used for keeping the best weights of the model.
    best_train_acc = base_train_acc
    best_val_acc = base_val_acc
    best_test_acc = base_test_acc

    # if not os.path.exists(fixed_weights_path):
    fixed_model.save_weights(fixed_weights_path)

    # print('iteration: {}, number of failing cases: {}'.format(iter_num, len(fail_xs)))
    logger(
        'iteration: {}, number of failing cases: {}'.format(
            iter_num, len(fail_xs)), log_path)

    print('start the main iteration process...')
    for i in range(iter_num):  # iterates by batch.
        try:
            # check if the index is in the fail_index.
            temp_train_index = train_total_index[i * iter_batch_size:(i + 1) *
                                                 iter_batch_size]
            temp_fail_idx_seq = fail_idx_seq[
                temp_train_index]  # temp binary indicator
            adjust_w = fixed_model.get_weights()
            # retrieve sub_correct_mat
            if np.sum(temp_fail_idx_seq) == 0:  # no failing cases.
                continue
            else:
                # exists failing cases.
                # get the failing case index
                # print(np.nonzero(temp_fail_idx_seq)[0])
                temp_fail_idx = temp_train_index[np.nonzero(temp_fail_idx_seq)
                                                 [0]]
                print('[iteration {}/{}]'.format(i, iter_num),
                      temp_fail_idx)  # Absolute index in train dataset
                for idx in temp_fail_idx:
                    sub_correct_mat_idx = int(np.sum(
                        fail_idx_seq[:idx + 1]))  # map the total idx back to the sub-mat idx.
                    # print(sub_correct_mat_idx)
                    temp_sub_corr_mat = sub_correct_mat[sub_correct_mat_idx]

                    # adjust weights
                    corr_avg, incorr_avg = get_avg_weights(
                        temp_sub_corr_mat, weights_list=sub_weights_list)
                    adjust_w = get_adjust_weights(adjust_w, corr_avg,
                                                  incorr_avg,
                                                  adjustment_strategy)

                # evaluation.
                fixed_model.set_weights(adjust_w)
                _, curr_acc = fixed_model.evaluate(x_val, y_val)
                print(
                    'After adjustment, the val acc: {:.4f}, best val acc: {:.4f}'
                    .format(curr_acc, best_val_acc))

                if curr_acc > best_val_acc:
                    best_val_acc = curr_acc
                    fixed_model.save_weights(fixed_weights_path)
                    # further training process.
                    checkpoint = ModelCheckpoint(fixed_weights_path,
                                                 monitor='val_accuracy',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
                    checkpoint.best = best_val_acc
                    hist = fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=20,  # further training epochs
                        callbacks=[checkpoint])
                    fixed_model.load_weights(fixed_weights_path)

                    temp_val_acc = np.max(
                        np.array(hist.history['val_accuracy']))
                    temp_train_acc = np.max(np.array(hist.history['accuracy']))
                    if temp_val_acc > best_val_acc:
                        # val acc improved.
                        best_val_acc = temp_val_acc
                        best_train_acc = temp_train_acc
                    _, best_test_acc = fixed_model.evaluate(x_test, y_test)
                    # print(best_train_acc, best_val_acc)
                    logger(
                        'Improved. Train acc: {:.4f}, val acc: {:.4f}, test acc: {:.4f}'
                        .format(best_train_acc, best_val_acc,
                                best_test_acc), log_path)
                    # fixed_model.save_weights(fixed_weights_path)
                else:  # worse than the best, rollback to the best case.
                    fixed_model.load_weights(fixed_weights_path)
        except:
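            # NOTE: the bare except silently skips any batch that raises an error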
            continue

    end = datetime.now()
    logger('Spend time: {}'.format(end - start), log_path)
Example #10
def apricorn(model, model_weights_dir, dataset):
    """
    the basic idea of apricorn is to update the rDLMs at the same time.
    """
    x_train, x_test, y_train, y_test = load_dataset(dataset)
    x_train_val, x_val, y_train_val, y_val = split_validation_dataset(
        x_train, y_train)

    train_size = len(x_train)
    val_size = len(x_train_val)
    test_size = len(x_test)

    fixed_model = model
    submodel_dir = os.path.join(model_weights_dir, 'submodels')
    trained_weights_path = os.path.join(model_weights_dir, 'trained.h5')
    fixed_weights_path = os.path.join(model_weights_dir, 'apricorn_fixed.h5')
    log_path = os.path.join(model_weights_dir, 'apricorn.log')

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant',
                                 cval=0.)
    datagen.fit(x_train)

    fixed_model.load_weights(trained_weights_path)
    start = datetime.now()

    sep_num = 5
    sep_count = 0

    logger('---------------original model---------------', log_path)
    # region initialization
    _, base_train_acc = fixed_model.evaluate(x_train_val, y_train_val)
    _, base_val_acc = fixed_model.evaluate(x_val, y_val)
    _, base_test_acc = fixed_model.evaluate(x_test, y_test)

    logger(
        'train acc: {:.4f}, val acc: {:.4f}, test acc: {:.4f}'.format(
            base_train_acc, base_val_acc, base_test_acc), log_path)

    fail_xs, fail_ys, fail_ys_label, fail_num, fail_index, correct_xs, correct_ys, correct_ys_label, correct_num, correct_index = get_indexed_failing_cases(
        fixed_model, x_train, y_train)
    print('getting sub correct matrix...')
    sub_correct_matrix_path = os.path.join(
        model_weights_dir, 'corr_mat_{}.npy'.format(NUM_SUBMODELS))
    sub_correct_matrix_path_2 = os.path.join(
        model_weights_dir, 'corr_mat_{}_2.npy'.format(NUM_SUBMODELS))

    if not os.path.exists(sub_correct_matrix_path):
        print('generating matrix....')
        sub_correct_mat = apricot_cal_sub_corr_mat(fixed_model,
                                                   submodel_dir,
                                                   fail_xs,
                                                   fail_ys,
                                                   num_submodels=NUM_SUBMODELS)
        np.save(sub_correct_matrix_path, sub_correct_mat)
    else:
        print('loading matrix...')
        sub_correct_mat = np.load(sub_correct_matrix_path)

    if not os.path.exists(sub_correct_matrix_path_2):
        print('generating matrix....')
        sub_correct_mat_2 = apricot_cal_sub_corr_mat(
            fixed_model,
            submodel_dir,
            correct_xs,
            correct_ys,
            num_submodels=NUM_SUBMODELS)
        np.save(sub_correct_matrix_path_2, sub_correct_mat_2)
    else:
        print('loading matrix...')
        sub_correct_mat_2 = np.load(sub_correct_matrix_path_2)

    fixed_model.load_weights(trained_weights_path)

    weights_list = get_weights_list(fixed_model, submodel_dir, NUM_SUBMODELS)

    best_train_acc = base_train_acc
    best_val_acc = base_val_acc
    best_test_acc = base_test_acc

    # if not os.path.exists(fixed_weights_path):
    fixed_model.save_weights(fixed_weights_path)
    # endregion

    # reduce the sub_correct_mat
    sub_correct_mat, sorted_idx, select_num = reduce_sub_corr_mat(
        sub_correct_mat, rate=0.1)

    sub_correct_mat_2, sorted_idx, select_num = reduce_sub_corr_mat(
        sub_correct_mat_2, rate=0.1)

    origin_sub_correct_mat = copy.deepcopy(sub_correct_mat)

    # Apricorn: iterates all failing cases.
    FIX_BATCH_SIZE = 5
    update_all = False
    impr_count = 0
    start = datetime.now()

    # failed case
    for count in range(1):
        # for i in range()
        iter_count, res = divmod(sub_correct_mat.shape[0], FIX_BATCH_SIZE)
        if res != 0:
            iter_count += 1
        print(iter_count)
        for i in range(iter_count):
            fixed_model.load_weights(fixed_weights_path)
            curr_w = fixed_model.get_weights()
            batch_corr_mat = sub_correct_mat[FIX_BATCH_SIZE *
                                             i:FIX_BATCH_SIZE * (i + 1)]
            adjust_w, adj_index_list = apricorn_batch_adjust_w(
                curr_w, batch_corr_mat, weights_list)  # update in lite way.
            # adjust_w = batch_get_adjust_w(curr_w, batch_corr_mat, weights_list)  # update in plus way.

            fixed_model.set_weights(adjust_w)

            x = int(count * sub_correct_mat.shape[0] + i + 1)
            y = int(sub_correct_mat.shape[0])
            _, curr_acc = fixed_model.evaluate(x_val, y_val)
            print(
                '[iteration {}/{}] current val acc: {:.4f}, best val acc: {:.4f}'
                .format(x, y, curr_acc, best_val_acc))

            if curr_acc > best_val_acc:
                best_val_acc = curr_acc
                fixed_model.save_weights(fixed_weights_path)
                logger('Improved. val acc: {:.4f}'.format(best_val_acc),
                       log_path)

                sep_count += 1
                if sep_count <= sep_num:  # cap how many times the further training is run.
                    # sep_count = 0
                    # train the fixed model.
                    checkpoint = ModelCheckpoint(fixed_weights_path,
                                                 monitor='val_accuracy',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
                    checkpoint.best = best_val_acc
                    hist = fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=5,  # further training epochs
                        callbacks=[checkpoint])
                    fixed_model.load_weights(fixed_weights_path)
                    temp_val_acc = np.max(
                        np.array(hist.history['val_accuracy']))
                else:
                    temp_val_acc = best_val_acc

                curr_w = fixed_model.get_weights()

                if temp_val_acc > best_val_acc:
                    # val acc improved.
                    best_val_acc = temp_val_acc
            else:
                impr_count += 1
                if impr_count == 10:
                    impr_count = 0
                    update_all = True
                fixed_model.load_weights(fixed_weights_path)

                # Apricorn: update weights list.
                # print('update weights list...')
                # # prepare the train
                # weights_list, sub_correct_mat = apricorn_update_weights_list(fixed_model, curr_w, batch_corr_mat, weights_list,
                #                                                              adj_index_list=adj_index_list,
                #                                                              datagen=datagen,
                #                                                              x_val=x_val,
                #                                                              y_val=y_val,
                #                                                              x_train_val=x_train_val,
                #                                                              y_train_val=y_train_val,
                #                                                              sub_correct_mat=sub_correct_mat,
                #                                                              fail_xs=fail_xs,
                #                                                              fail_ys=fail_ys,
                #                                                              index=sorted_idx,
                #                                                              num=select_num,
                #                                                              update_all=update_all)  # lr=0.01

            # else:
            #     impr_count += 1
            #     if impr_count == 10:
            #         impr_count = 0
            #         update_all = True
            #     fixed_model.load_weights(fixed_weights_path)

            # print('update weights list...')
            # # prepare the train
            # weights_list, sub_correct_mat = apricorn_update_weights_list(fixed_model, curr_w, batch_corr_mat,
            #                                                              weights_list,
            #                                                              adj_index_list=adj_index_list,
            #                                                              datagen=datagen,
            #                                                              x_val=x_val,
            #                                                              y_val=y_val,
            #                                                              x_train_val=x_train_val,
            #                                                              y_train_val=y_train_val,
            #                                                              sub_correct_mat=sub_correct_mat,
            #                                                              fail_xs=fail_xs,
            #                                                              fail_ys=fail_ys,
            #                                                              index=sorted_idx,
            #                                                              num=select_num)  # lr=0.01

        # sub_correct_mat = copy.deepcopy(origin_sub_correct_mat)
        # np.random.shuffle(sub_correct_mat)

    # correct case.
    for count in range(1):
        # for i in range()
        iter_count, res = divmod(sub_correct_mat_2.shape[0], FIX_BATCH_SIZE)
        if res != 0:
            iter_count += 1
        print(iter_count)
        for i in range(iter_count):
            fixed_model.load_weights(fixed_weights_path)
            curr_w = fixed_model.get_weights()
            batch_corr_mat = sub_correct_mat_2[FIX_BATCH_SIZE *
                                               i:FIX_BATCH_SIZE * (i + 1)]
            adjust_w, adj_index_list = apricorn_batch_adjust_w(
                curr_w, batch_corr_mat, weights_list)  # update in lite way.
            # adjust_w = batch_get_adjust_w(curr_w, batch_corr_mat, weights_list)  # update in plus way.

            fixed_model.set_weights(adjust_w)

            x = int(count * sub_correct_mat_2.shape[0] + i + 1)
            y = int(sub_correct_mat_2.shape[0])
            _, curr_acc = fixed_model.evaluate(x_val, y_val)
            print(
                '[iteration {}/{}] current val acc: {:.4f}, best val acc: {:.4f}'
                .format(x, y, curr_acc, best_val_acc))

            if curr_acc > best_val_acc:
                best_val_acc = curr_acc
                fixed_model.save_weights(fixed_weights_path)
                logger('Improved. val acc: {:.4f}'.format(best_val_acc),
                       log_path)

                sep_count += 1
                if sep_count <= sep_num:  # cap how many times the further training is run.
                    # sep_count = 0
                    # train the fixed model.
                    checkpoint = ModelCheckpoint(fixed_weights_path,
                                                 monitor='val_accuracy',
                                                 verbose=1,
                                                 save_best_only=True,
                                                 mode='max')
                    checkpoint.best = best_val_acc
                    hist = fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=5,  # further training epochs
                        callbacks=[checkpoint])
                    fixed_model.load_weights(fixed_weights_path)
                    temp_val_acc = np.max(
                        np.array(hist.history['val_accuracy']))
                else:
                    temp_val_acc = best_val_acc

                curr_w = fixed_model.get_weights()

                if temp_val_acc > best_val_acc:
                    # val acc improved.
                    best_val_acc = temp_val_acc
            else:
                impr_count += 1
                if impr_count == 10:
                    impr_count = 0
                    update_all = True
                fixed_model.load_weights(fixed_weights_path)

    end = datetime.now()
    logger('Spend time: {}'.format(end - start), log_path)
Example #11
def apricot_plus_lite(model,
                      model_name,
                      get_trained_weights,
                      x_train_val,
                      y_train_val,
                      x_val,
                      y_val,
                      x_test,
                      y_test,
                      adjustment_strategy,
                      activation='binary',
                      ver=1,
                      dataset='cifar10',
                      max_count=1,
                      loop_count=100000,
                      random_seed=42):
    weights_dir = os.path.join(WEIGHTS_DIR, 'CNN')
    weights_dir = os.path.join(weights_dir, model_name)
    weights_dir = os.path.join(weights_dir, '{}'.format(ver))

    # create the dir
    if not os.path.isdir(weights_dir):
        os.makedirs(weights_dir)

    if get_trained_weights:
        model.load_weights(os.path.join(weights_dir, 'trained.h5'))

    weights_after_dir = os.path.join(
        weights_dir, 'fixed_{}_{}.h5'.format(adjustment_strategy, activation))

    if not os.path.exists(weights_after_dir):
        model.save_weights(weights_after_dir)

    datagen = ImageDataGenerator(horizontal_flip=True,
                                 width_shift_range=0.125,
                                 height_shift_range=0.125,
                                 fill_mode='constant',
                                 cval=0.)

    datagen.fit(x_train_val)

    # build the fixed model.
    if dataset == 'cifar10':
        img_rows, img_cols = 32, 32
        img_channels = 3
        num_classes = 10
        top_k = 1
    elif dataset == 'cifar100':
        img_rows, img_cols = 32, 32
        img_channels = 3
        num_classes = 100
        top_k = 5
    else:
        pass

    input_tensor = Input(shape=(img_rows, img_cols, img_channels))

    if model_name == 'resnet20':
        fixed_model = build_resnet(img_rows,
                                   img_cols,
                                   img_channels,
                                   num_classes=num_classes,
                                   stack_n=3,
                                   k=top_k)
    elif model_name == 'resnet32':
        fixed_model = build_resnet(img_rows,
                                   img_cols,
                                   img_channels,
                                   num_classes=num_classes,
                                   stack_n=5,
                                   k=top_k)
    elif model_name == 'mobilenet':
        fixed_model = build_mobilenet(input_tensor,
                                      num_classses=num_classes,
                                      k=top_k)
    elif model_name == 'mobilenet_v2':
        fixed_model = build_mobilenet_v2(input_tensor,
                                         num_classses=num_classes,
                                         k=top_k)
    elif model_name == 'densenet':
        fixed_model = build_densenet(input_tensor,
                                     num_classses=num_classes,
                                     k=top_k)

    fixed_model.load_weights(os.path.join(weights_dir, 'trained.h5'))
    # fixed_model = copy.deepcopy(model)

    # evaluate the acc before fixing.
    print('----------origin model----------')
    if dataset in ['cifar100', 'imagenet']:
        _, acc_top_1_train, train_acc = fixed_model.evaluate(
            x_train_val, y_train_val)
        print(
            '[==log==] training acc. before fixing: top-1: {:.4f}, top-5: {:.4f}'
            .format(acc_top_1_train, train_acc))
        _, acc_top_1_val, origin_acc = fixed_model.evaluate(x_val, y_val)
        print(
            '[==log==] validation acc. before fixing: top-1: {:.4f}, top-5: {:.4f}'
            .format(acc_top_1_val, origin_acc))
        _, acc_top_1_test, test_acc = fixed_model.evaluate(x_test, y_test)
        print(
            '[==log==] test acc. before fixing: top-1: {:.4f}, top-5: {:.4f}'.
            format(acc_top_1_test, test_acc))
        logger(weights_dir, '========================')
        logger(
            weights_dir, 'model: {}, adjustment strategy: {}, ver: {}'.format(
                model_name, adjustment_strategy, ver))
        logger(
            weights_dir,
            'TOP-1: train acc.: {:4f}, val acc.: {:4f}, test acc.: {:4f}'.
            format(acc_top_1_train, acc_top_1_val, acc_top_1_test))
        logger(
            weights_dir,
            'TOP-5: train acc.: {:4f}, val acc.: {:4f}, test acc.: {:4f}'.
            format(train_acc, origin_acc, test_acc))

    else:
        _, origin_acc = fixed_model.evaluate(x_val, y_val)
        print('----------origin model----------')
        _, train_acc = fixed_model.evaluate(x_train_val, y_train_val)
        print(
            '[==log==] training acc. before fixing: {:.4f}'.format(train_acc))
        print('[==log==] validation acc. before fixing: {:.4f}'.format(
            origin_acc))
        _, test_acc = fixed_model.evaluate(x_test, y_test)
        print('[==log==] test acc. before fixing: {:.4f}'.format(test_acc))
        logger(weights_dir, '========================')
        logger(
            weights_dir, 'model: {}, adjustment strategy: {}, ver: {}'.format(
                model_name, adjustment_strategy, ver))
        logger(
            weights_dir,
            'train acc.: {:.4f}, val acc.: {:.4f}, test acc.: {:.4f}'.format(
                train_acc, origin_acc, test_acc))

    # start time
    start_time = datetime.now()

    # start fixing
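    # keep a rollback snapshot: whenever an adjustment hurts validation
    # accuracy, the loop below restores these weights and this accuracy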
    best_weights = fixed_model.get_weights()
    best_acc = origin_acc

    # find the indices of all training samples that the original model misclassifies.
    y_preds = model.predict(x_train_val)
    y_pred_label = np.argmax(y_preds, axis=1)
    y_true = np.argmax(y_train_val, axis=1)

    index_diff = np.nonzero(y_pred_label - y_true)
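    # np.nonzero over the label difference gives the positions where the
    # prediction disagrees with the ground truth, e.g.
    #   y_pred_label = [3, 1, 7], y_true = [3, 2, 7]  ->  index_diff = ([1],)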

    fail_xs = x_train_val[index_diff]
    fail_ys = y_train_val[index_diff]
    fail_ys_label = np.argmax(fail_ys, axis=1)
    fail_num = int(np.size(index_diff))

    sub_correct_matrix_path = os.path.join(
        weights_dir, 'corr_matrix_{}.npy'.format(random_seed))
    sub_correct_matrix = None  # 1: predicts correctly, 0: predicts incorrectly.

    sub_weights_list = None
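    # sub_correct_matrix ends up with one row per failing training sample and
    # one column per submodel weight file; sub_weights_list is expected to
    # hold each submodel's weights in the same column order. Caching the
    # matrix per random_seed lets re-runs skip the expensive submodel sweep.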

    if not os.path.exists(sub_correct_matrix_path):
        # obtain submodel correctness matrix
        submodels_path = os.path.join(weights_dir, 'submodels')

        for root, dirs, files in os.walk(submodels_path):
            for f in files:
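                # each file under submodels/ is assumed to be one submodel's
                # weight file; it is evaluated only on the samples the full
                # model gets wrong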
                temp_w_path = os.path.join(root, f)
                fixed_model.load_weights(temp_w_path)
                sub_y_pred = fixed_model.predict(fail_xs)

                # top-1 accuracy
                if dataset not in ['cifar100', 'imagenet']:
                    sub_col = np.argmax(sub_y_pred, axis=1) - fail_ys_label
                    sub_col[sub_col != 0] = 1
                # top-5 accuracy
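                # K.in_top_k yields a boolean per failing sample; the 1 - x
                # flip below turns it into 1 = "also missed in top-5",
                # matching the top-1 convention above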
                else:
                    sub_col = K.in_top_k(sub_y_pred, K.argmax(fail_ys,
                                                              axis=-1), 5)
                    sub_col = K.get_value(sub_col)
                    sub_col = sub_col.astype(int)
                    sub_col = np.ones(shape=sub_col.shape) - sub_col

                if sub_correct_matrix is None:
                    sub_correct_matrix = sub_col.reshape(fail_num, 1)
                else:
                    sub_correct_matrix = np.concatenate(
                        (sub_correct_matrix, sub_col.reshape(fail_num, 1)),
                        axis=1)

        # flip after the walk over all submodel weight files: entries become
        # 1 where a submodel predicts the failing sample correctly, 0 where
        # it fails as well.
        sub_correct_matrix = np.ones(
            shape=sub_correct_matrix.shape) - sub_correct_matrix
        np.save(sub_correct_matrix_path, sub_correct_matrix)

    else:
        sub_correct_matrix = np.load(sub_correct_matrix_path)

    # load every submodel's weight list once so the adjustment loop can reuse
    # them without re-reading the files from disk (needed in both branches
    # above)
    sub_weights_list = get_submodels_weights(
        fixed_model, model_name, dataset,
        os.path.join(weights_dir, 'submodels'))

    # main loop
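    # the submodel sweep above may have left a submodel's weights inside
    # fixed_model, so reload the working weights before adjusting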
    fixed_model.load_weights(weights_after_dir)

    logger(weights_dir, '-----------------')
    logger(weights_dir, 'adjustment strategy {}'.format(adjustment_strategy))
    logger(
        weights_dir,
        'LOOP_COUNT: {}, BATCH_SIZE: {}, learning_rate: {}'.format(
            loop_count, BATCH_SIZE, learning_rate))
    logger(
        weights_dir,
        'PRE_EPOCHS: {}, AFTER_EPOCHS: {}, SUB_EPOCHS: {}, MAX_COUNT: {}'.
        format(PRE_EPOCHS, AFTER_EPOCHS, SUB_EPOCHS, max_count))
    logger(weights_dir, '-----------------')

    for _ in range(loop_count):
        np.random.shuffle(sub_correct_matrix)
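        # shuffling the rows changes the order in which failing samples are
        # visited on each outer pass; the column-to-submodel mapping is kept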
        iter_count = 0
        for index in range(sub_correct_matrix.shape[0]):

            if iter_count >= max_count:
                break

            curr_weights = fixed_model.get_weights()
            corr_mat = sub_correct_matrix[index, :]
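            # corr_mat is this failing sample's row: a 0/1 flag per submodel,
            # 1 meaning that submodel classifies the sample correctly. Judging
            # by the names, get_adjustment_weights splits the submodel weights
            # into "correct" and "incorrect" groups and adjust_weights_func
            # nudges the current weights according to adjustment_strategy.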

            # lite version
            corr_w, incorr_w = get_adjustment_weights(corr_mat,
                                                      sub_weights_list,
                                                      adjustment_strategy)
            adjust_w = adjust_weights_func(curr_weights,
                                           corr_w,
                                           incorr_w,
                                           adjustment_strategy,
                                           activation=activation)

            # adjust_weights_func returns -1 when no adjustment applies to
            # this sample; the scalar check avoids comparing a weight list
            # element-wise
            if np.isscalar(adjust_w) and adjust_w == -1:
                continue

            fixed_model.set_weights(adjust_w)

            if dataset not in ['cifar100', 'imagenet']:
                _, curr_acc = fixed_model.evaluate(x_val, y_val, verbose=0)
            else:
                _, _, curr_acc = fixed_model.evaluate(x_val, y_val, verbose=0)
            print(
                'tried times: {}, validation accuracy after adjustment: {:.4f}'
                .format(index, curr_acc))
            if curr_acc > best_acc:
                best_acc = curr_acc
                fixed_model.save_weights(weights_after_dir)

                if adjustment_strategy <= 3:
                    # Apricot+ further training process
                    if dataset not in ['cifar100', 'imagenet']:
                        checkpoint = ModelCheckpoint(weights_after_dir,
                                                     monitor=MONITOR,
                                                     verbose=1,
                                                     save_best_only=True,
                                                     mode='max')
                    else:
                        checkpoint = ModelCheckpoint(weights_after_dir,
                                                     monitor='val_top_k_acc',
                                                     verbose=1,
                                                     save_best_only=True,
                                                     mode='max')

                    checkpoint.best = best_acc
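                    # seeding the callback with the running best means further
                    # training only overwrites the weights file on a genuine
                    # improvement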
                    fixed_model.fit_generator(
                        datagen.flow(x_train_val,
                                     y_train_val,
                                     batch_size=BATCH_SIZE),
                        steps_per_epoch=len(x_train_val) // BATCH_SIZE + 1,
                        validation_data=(x_val, y_val),
                        epochs=FURTHER_ADJUSTMENT_EPOCHS,
                        callbacks=[checkpoint])
                    fixed_model.load_weights(weights_after_dir)
                    # keep the rollback snapshot in sync with the retrained
                    # weights; otherwise a later failed adjustment would
                    # restore a stale copy
                    best_weights = fixed_model.get_weights()

                    if dataset not in ['cifar100', 'imagenet']:
                        _, val_acc = fixed_model.evaluate(x_val,
                                                          y_val,
                                                          verbose=0)
                        _, test_acc = fixed_model.evaluate(x_test,
                                                           y_test,
                                                           verbose=0)
                    else:
                        _, _, val_acc = fixed_model.evaluate(x_val,
                                                             y_val,
                                                             verbose=0)
                        _, _, test_acc = fixed_model.evaluate(x_test,
                                                              y_test,
                                                              verbose=0)
                    # track the (possibly improved) validation accuracy of the
                    # reloaded weights so later comparisons use the real best
                    best_acc = max(best_acc, val_acc)

                    print('validation acc. after retraining: {:.4f}'.format(
                        val_acc))
                    print(
                        'test acc. after retraining: {:.4f}'.format(test_acc))
                    logger(
                        weights_dir,
                        'Improved, validation acc.: {:.4f}, test acc.:{:.4f}'.
                        format(val_acc, test_acc))

                else:
                    print('-----------------------------')
                    print('evaluate on test dataset.')
                    # best_acc was raised and the weights file refreshed above;
                    # only the in-memory rollback snapshot still needs updating
                    best_weights = adjust_w
                    # evaluation
                    if dataset not in ['cifar100', 'imagenet']:
                        _, val_acc = fixed_model.evaluate(x_val,
                                                          y_val,
                                                          verbose=0)
                        _, test_acc = fixed_model.evaluate(x_test,
                                                           y_test,
                                                           verbose=0)
                    else:
                        _, _, val_acc = fixed_model.evaluate(x_val,
                                                             y_val,
                                                             verbose=0)
                        _, _, test_acc = fixed_model.evaluate(x_test,
                                                              y_test,
                                                              verbose=0)

                    print('validation acc. after adjustment: {:.4f}'.format(
                        val_acc))
                    print(
                        'test acc. after adjustment: {:.4f}'.format(test_acc))
                    logger(
                        weights_dir,
                        'Improved, validation acc.: {:.4f}, test acc.:{:.4f}'.
                        format(val_acc, test_acc))

            else:
                fixed_model.set_weights(best_weights)

            iter_count += 1

    # further training process.
    if dataset not in ['cifar100', 'imagenet']:
        checkpoint = ModelCheckpoint(weights_after_dir,
                                     monitor=MONITOR,
                                     verbose=1,
                                     save_best_only=True,
                                     mode='max')
    else:
        checkpoint = ModelCheckpoint(weights_after_dir,
                                     monitor='val_top_k_acc',
                                     verbose=1,
                                     save_best_only=True,
                                     mode='max')

    checkpoint.best = best_acc
    fixed_model.fit_generator(datagen.flow(x_train_val,
                                           y_train_val,
                                           batch_size=BATCH_SIZE),
                              steps_per_epoch=len(x_train_val) // BATCH_SIZE +
                              1,
                              validation_data=(x_val, y_val),
                              epochs=FURTHER_ADJUSTMENT_EPOCHS,
                              callbacks=[checkpoint])

    # end time
    end_time = datetime.now()
    time_delta = end_time - start_time
    print('time used for adaptation: {}'.format(str(time_delta)))
    logger(weights_dir, 'time used for adaptation: {}'.format(str(time_delta)))

    fixed_model.load_weights(weights_after_dir)
    best_weights = fixed_model.get_weights()

    if dataset in ['cifar100', 'imagenet']:
        _, acc_top_1_train, train_acc = fixed_model.evaluate(
            x_train_val, y_train_val)
        _, acc_top_1_val, origin_acc = fixed_model.evaluate(x_val, y_val)
        _, acc_top_1_test, test_acc = fixed_model.evaluate(x_test, y_test)

        print(
            'after adjustment and retraining, TOP-1 train acc.: {}, val acc.: {}, test acc.: {}'
            .format(acc_top_1_train, acc_top_1_val, acc_top_1_test))
        print(
            'after adjustment and retraining, TOP-5 train acc.: {}, val acc.: {}, test acc.: {}'
            .format(train_acc, origin_acc, test_acc))

        logger(
            weights_dir,
            'after adjustment and retraining, TOP-1 train acc.: {}, val acc.: {}, test acc.: {}'
            .format(acc_top_1_train, acc_top_1_val, acc_top_1_test))
        logger(
            weights_dir,
            'after adjustment and retraining, TOP-5 train acc.: {}, val acc.: {}, test acc.: {}'
            .format(train_acc, origin_acc, test_acc))

    else:
        _, train_acc = fixed_model.evaluate(x_train_val,
                                            y_train_val,
                                            verbose=0)
        _, val_acc = fixed_model.evaluate(x_val, y_val, verbose=0)
        _, test_acc = fixed_model.evaluate(x_test, y_test, verbose=0)

        print('validation acc. after retraining: {:.4f}'.format(val_acc))
        print('test acc. after retraining: {:.4f}'.format(test_acc))

        logger(
            weights_dir,
            'after adjustment and retraining, train acc.: {}, val acc.: {}, test acc.: {}'
            .format(train_acc, val_acc, test_acc))