def test_preprocess_image():
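    # Pull one P x K batch (4 identities x 10 images) from the loader and
    # save the first preprocessed image so it can be inspected visually.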
    data = ImageDataLoader('.', transforms=[preprocess_image], p=4, k=10)
    data_iter = data.flow()
    batch = next(data_iter)
    img = batch[0][0]
    array_to_img(img).save('test.jpg')
Example #2
def train():
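    # Enable eager execution when running under TF 1.x.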
    if tf.__version__.split('.')[0] != '2':
        tf.enable_eager_execution()

    transform = [preprocess_image]

    p = 4
    k = 16
    batch_size = p * k
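    # Each batch holds p identities with k images each (P x K sampling),
    # which is what the batch-hard triplet loss needs in order to mine
    # positives and negatives within a batch.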
    learning_rate = 0.05
    epochs = 120
    img_shape = (256, 256)
    margin = 0.3

    data = ImageDataLoader('.', transforms=transform, p=p, k=k)

    class_num = len(data.classes)
    steps_per_epoch = (class_num // p) * 40
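    # (class_num // p) steps visit every identity once; each epoch makes
    # roughly 40 such passes.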

    strategy = tf.distribute.MirroredStrategy()
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    global_batch_size = batch_size * strategy.num_replicas_in_sync
    num_replicas_in_sync = strategy.num_replicas_in_sync

    batch_hard_func = BatchHard(
        margin=margin, reduction=losses_utils.ReductionV2.NONE)
    id_loss_func = SparseCategoricalCrossentropy(
        reduction=losses_utils.ReductionV2.NONE)
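    # Both losses use Reduction.NONE so the per-example values can be summed
    # and divided by the global batch size inside loss_func below.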

    id_loss_metrics = tf.keras.metrics.Mean()
    id_corrects = tf.keras.metrics.SparseCategoricalAccuracy()

    # running_corrects and running_margin are metrics maintained by the
    # BatchHard loss; they are logged and reset at the end of each epoch.
    running_corrects = batch_hard_func.running_corrects
    running_margin = batch_hard_func.running_margin
    triple_loss_metrics = tf.keras.metrics.Mean()

    def loss_func(id_output, features, labels):
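        # Combined objective: batch-hard triplet loss on the embedding
        # features plus softmax (ID) cross-entropy on the classifier output,
        # each summed and scaled by the global batch size for correct
        # averaging across replicas.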
        triple_loss = tf.reduce_sum(
            batch_hard_func(labels, features)) / global_batch_size
        id_loss = tf.reduce_sum(id_loss_func(
            labels, id_output)) / global_batch_size
        id_loss_metrics.update_state(id_loss)
        triple_loss_metrics.update_state(triple_loss)
        return id_loss + triple_loss

    with strategy.scope():
        model = build_baseline_model(class_num, img_shape)
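        # Two optimizers: the pretrained resnet50 backbone is fine-tuned at
        # one tenth of the base learning rate, while the newly added head
        # layers train at the full rate.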

        finetune_weights = model.get_layer(name='resnet50').trainable_weights
        finetune_optimizer = SGD(
            learning_rate=learning_rate * 0.1, momentum=0.9, nesterov=True)

        train_weights = [
            w for w in model.trainable_weights if w not in finetune_weights]
        optimizer = SGD(learning_rate=learning_rate,
                        momentum=0.9, nesterov=True)

        all_weights = finetune_weights + train_weights

        # sgd = SGD(learning_rate=1)

        learning_rate_scheduler = LearningRateScheduler(
            [optimizer, finetune_optimizer])

        data_iter = data.flow()

        with open('checkpoint/model.json', 'w', encoding='utf-8') as fp:
            fp.write(model.to_json())

        def train_step(batch):
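            # One forward/backward pass on a single replica; the gradient
            # list is split between the backbone and head optimizers below.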
            imgs, labels = batch

            with tf.GradientTape(persistent=True) as tape:
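                # persistent=True would allow a second tape.gradient call
                # (needed by the commented-out weight-decay term below).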
                id_output, features = model(imgs)

                loss = loss_func(id_output, features, labels)
                # l2_loss = weight_decay * \
                #     tf.add_n([tf.nn.l2_loss(v)
                #               for v in model.trainable_weights])

            grads = tape.gradient(loss, all_weights)
            # l2_grads = tape.gradient(l2_loss, model.trainable_weights)

            finetune_grads = grads[:len(finetune_weights)]
            train_grads = grads[len(finetune_weights):]

            finetune_optimizer.apply_gradients(
                zip(finetune_grads, finetune_weights))
            optimizer.apply_gradients(zip(train_grads, train_weights))
            # sgd.apply_gradients(zip(l2_grads, model.trainable_weights))

            id_corrects.update_state(labels, id_output)

            return loss

        @tf.function
        def distributed_train_step(batch):
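            # Run train_step on every replica and sum the per-replica losses.
            # (Strategy.experimental_run_v2 was renamed to Strategy.run in
            # later TF 2.x releases.)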
            per_replica_losses = strategy.experimental_run_v2(
                train_step, args=(batch,))
            loss = strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                                   axis=None)
            return loss

        # model.load_weights('checkpoint/30.h5')
        # learning_rate_scheduler([optimizer, finetune_optimizer], 20)

        with K.learning_phase_scope(1):
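            # Force the Keras learning phase to "training" so layers such as
            # BatchNorm and Dropout behave as in training when the model is
            # called directly.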
            history = defaultdict(list)

            for cur_epoch in range(1, epochs + 1):
                print('Epoch {}/{}'.format(cur_epoch, epochs))
                progbar = Progbar(steps_per_epoch)

                learning_rate_scheduler(cur_epoch)

                for i in range(steps_per_epoch):
                    batch = next(data_iter)
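                    # Skip incomplete batches so the P x K structure (and the
                    # fixed global batch size) is preserved.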
                    if len(batch[1]) != batch_size:
                        batch = next(data_iter)
                        assert len(batch[1]) == batch_size

                    loss = distributed_train_step(batch)

                    cur_data = [('loss', loss), ('id_acc', id_corrects.result())]

                    progbar.add(1, values=cur_data)

                print(
                    f'acc: {running_corrects.result()} margin: {running_margin.result()}')
                print(
                    f'id acc: {id_corrects.result()} id loss: {id_loss_metrics.result()}')
                print(
                    f'triple_loss: {triple_loss_metrics.result()}')
                running_corrects.reset_states()
                running_margin.reset_states()
                triple_loss_metrics.reset_states()
                id_corrects.reset_states()
                id_loss_metrics.reset_states()

                for key, val in cur_data:
                    history[key].append(float(val))

                with open('checkpoint/history.json', 'w') as fp:
                    json.dump(history, fp)

                if cur_epoch % 5 == 0:
                    model.save_weights(f'checkpoint/{cur_epoch}.h5')