Example No. 1
def main():
    dataset = CIFAR10(
        binary=True, validation_split=0.0)  # not using validation for anything
    model = mobilenet_v2_like(dataset.input_shape, dataset.num_classes)

    model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=SGDW(learning_rate=0.01, momentum=0.9, weight_decay=1e-5),
                  metrics=['accuracy'])
    model.summary()

    batch_size = 128

    train_data = dataset.train_dataset() \
        .shuffle(8 * batch_size) \
        .batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    valid_data = dataset.test_dataset() \
        .batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    def lr_schedule(epoch):
        if 0 <= epoch < 35:
            return 0.01
        if 35 <= epoch < 65:
            return 0.005
        return 0.001

    model.fit(train_data,
              validation_data=valid_data,
              epochs=80,
              callbacks=[LearningRateScheduler(lr_schedule)])
    model.save("cnn-cifar10-binary.h5")
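
A note on Example No. 1: the LearningRateScheduler callback only adjusts the learning rate, while SGDW's decoupled weight decay stays at its constructor value; the TensorFlow Addons docs recommend decaying both together. A minimal alternative sketch, mirroring Example No. 4 below and assuming CIFAR-10's 50,000 training images:

import tensorflow as tf
import tensorflow_addons as tfa

# Assumed: 50,000 training images / batch_size 128 is roughly 390 steps per epoch.
steps_per_epoch = 50_000 // 128
boundaries = [35 * steps_per_epoch, 65 * steps_per_epoch]
decay = [1.0, 0.5, 0.1]  # mirrors the 0.01 -> 0.005 -> 0.001 epoch schedule above

lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, [0.01 * d for d in decay])
wd_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, [1e-5 * d for d in decay])

optimizer = tfa.optimizers.SGDW(learning_rate=lr_schedule,
                                momentum=0.9,
                                weight_decay=wd_schedule)
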
Example No. 2
    def test_save_model_and_load_model_tf_optimizer(self):
        m1 = fe.build(
            fe.architecture.tensorflow.LeNet,
            optimizer_fn=lambda: SGDW(weight_decay=2e-5, learning_rate=2e-4))
        temp_folder = tempfile.mkdtemp()
        fe.backend.save_model(m1,
                              save_dir=temp_folder,
                              model_name="test",
                              save_optimizer=True)

        m2 = fe.build(
            fe.architecture.tensorflow.LeNet,
            optimizer_fn=lambda: SGDW(weight_decay=1e-5, learning_rate=1e-4))
        fe.backend.load_model(m2,
                              weights_path=os.path.join(
                                  temp_folder, "test.h5"),
                              load_optimizer=True)
        self.assertTrue(np.allclose(fe.backend.get_lr(model=m2), 2e-4))
        self.assertTrue(
            np.allclose(
                tf.keras.backend.get_value(m2.current_optimizer.weight_decay),
                2e-5))
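
Example No. 2 uses FastEstimator's save_model/load_model helpers to round-trip the optimizer state along with the weights. With plain Keras, a model saved as in Example No. 1 can be reloaded as long as the Addons optimizer class is resolvable at deserialization time; a minimal sketch, assuming the cnn-cifar10-binary.h5 file produced above:

import tensorflow as tf
import tensorflow_addons as tfa

# custom_objects makes SGDW visible to the Keras deserializer (recent Addons
# versions also register the class on import, but being explicit is safer).
model = tf.keras.models.load_model(
    "cnn-cifar10-binary.h5",
    custom_objects={"SGDW": tfa.optimizers.SGDW})
model.summary()
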
Example No. 3
search_algorithm = AgingEvoSearch


def lr_schedule(epoch):
    if 0 <= epoch < 25:
        return 0.01
    if 25 <= epoch < 35:
        return 0.005
    return 0.001


training_config = TrainingConfig(
    dataset=FashionMNIST(),
    batch_size=128,
    epochs=45,
    optimizer=lambda: SGDW(learning_rate=0.01, momentum=0.9, weight_decay=1e-5),
    callbacks=lambda: [LearningRateScheduler(lr_schedule)]
)

search_config = AgingEvoConfig(
    search_space=CnnSearchSpace(dropout=0.15),
    checkpoint_dir="artifacts/cnn_fashion"
)

bound_config = BoundConfig(
    error_bound=0.10,
    peak_mem_bound=64000,
    model_size_bound=64000,
    mac_bound=30000000
)
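
In Example No. 3, both optimizer and callbacks are given as factories (lambdas) rather than instances, presumably so that every candidate model trained during the search gets its own optimizer state and callback objects. A tiny sketch of the difference:

from tensorflow_addons.optimizers import SGDW

make_optimizer = lambda: SGDW(learning_rate=0.01, momentum=0.9, weight_decay=1e-5)

# Each call builds an independent optimizer, so momentum slots and the
# iteration counter are never shared between candidate models.
opt_a = make_optimizer()
opt_b = make_optimizer()
assert opt_a is not opt_b
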
Example No. 4
        steps = [(90 - initial_epoch) * steps_per_epoch]
        decay = [0.01, 0.001]
        lr_schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
            steps, [0.05 * d for d in decay])
        wd_schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
            steps, [0.0001 * d for d in decay])
    else:
        lr_schedule = 0.05 * 0.001
        wd_schedule = 0.0001 * 0.001

    # Create and compile TF model
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        tf_model = vgg16()
        optimizer = SGDW(learning_rate=lr_schedule,
                         momentum=momentum,
                         nesterov=True,
                         weight_decay=wd_schedule)
        tf_model.compile(optimizer=optimizer,
                         loss='sparse_categorical_crossentropy',
                         metrics=['accuracy'])

    if args.reuse_tf_model:
        # Load old weights
        tf_model.load_weights('vgg16_imagenet_tf_weights.h5')

    else:
        # Load newest checkpoint weights if present
        if newest_checkpoint_file is not None:
            print(
                f'Loading epoch {initial_epoch} from checkpoint {newest_checkpoint_file}'
            )
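
In the scheduled branch of Example No. 4, the single boundary at (90 - initial_epoch) * steps_per_epoch places the decay at overall epoch 90, which suggests the optimizer's step counter restarts from zero when training resumes. A small sketch with assumed numbers showing how the boundary and values line up:

import tensorflow as tf

# Assumed values, for illustration only.
initial_epoch = 80
steps_per_epoch = 10_000

steps = [(90 - initial_epoch) * steps_per_epoch]   # boundary 10 epochs into this run
decay = [0.01, 0.001]
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    steps, [0.05 * d for d in decay])

print(lr_schedule(0).numpy())             # 5e-04 before the boundary
print(lr_schedule(steps[0] + 1).numpy())  # 5e-05 after the boundary
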
Example No. 5
def main(argv):
    del argv
    # path
    data_dir = os.path.join(BASE_DIR, 'dataset', FLAGS.dataset)
    exp_dir = os.path.join(data_dir, 'exp', FLAGS.exp_name)
    model_dir = os.path.join(exp_dir, 'ckpt')
    log_dir = exp_dir
    os.makedirs(model_dir, exist_ok=True)
    # os.makedirs(log_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'model-{epoch:04d}.ckpt.h5')

    # logging
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.DEBUG,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(
        '------------------------------experiment start------------------------------------'
    )

    for i in (
            'exp_name',
            'dataset',
            'model',
            'mode',
            'lr',
    ):
        logging.info(
            '%s: %s' %
            (i, FLAGS.get_flag_value(i, '########VALUE MISSED#########')))
    logging.info(FLAGS.flag_values_dict())

    # resume from checkpoint
    largest_epoch = 0
    if FLAGS.resume == 'ckpt':
        chkpts = tf.io.gfile.glob(model_dir + '/*.ckpt.h5')
        if len(chkpts):
            largest_epoch = sorted([int(i[-12:-8]) for i in chkpts],
                                   reverse=True)[0]
            print('resume from epoch', largest_epoch)
            weight_path = model_path.format(epoch=largest_epoch)
        else:
            weight_path = None
    elif len(FLAGS.resume):
        assert os.path.isfile(FLAGS.resume)
        weight_path = FLAGS.resume
    else:
        weight_path = None

    dataset = importlib.import_module(
        'dataset.%s.data_loader' %
        FLAGS.dataset).DataLoader(**FLAGS.flag_values_dict())
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = globals()[FLAGS.model](**FLAGS.flag_values_dict())
        # model = alexnet()
        if FLAGS.resume and weight_path:
            logging.info('resume from previous ckp: %s' % largest_epoch)
            model.load_weights(weight_path)
        # model.layers[1].trainable = False
        loss = globals()[FLAGS.loss]
        model.compile(
            optimizer=SGDW(momentum=0.9,
                           learning_rate=FLAGS.lr,
                           weight_decay=FLAGS.weight_decay),
            loss=loss,
            metrics=[
                "accuracy",
                Recall(),
                Precision(),
                MeanIoU(num_classes=FLAGS.classes)
            ],
        )
        # if 'train' in FLAGS.mode:
        #     model.summary()
        logging.info('There are %s layers in model' % len(model.layers))
        if FLAGS.freeze_layers > 0:
            logging.info('Freeze first %s layers' % FLAGS.freeze_layers)
            for i in model.layers[:FLAGS.freeze_layers]:
                i.trainable = False
        verbose = 1 if FLAGS.debug else 2
        if 'train' in FLAGS.mode:
            callbacks = [
                model_checkpoint(filepath=model_path,
                                 monitor=FLAGS.model_checkpoint_monitor),
                tensorboard(log_dir=os.path.join(exp_dir, 'tb-logs')),
                early_stopping(monitor=FLAGS.model_checkpoint_monitor,
                               patience=FLAGS.early_stopping_patience),
                lr_schedule(name=FLAGS.lr_schedule, epochs=FLAGS.epoch)
            ]
            file_writer = tf.summary.create_file_writer(
                os.path.join(exp_dir, 'tb-logs', "metrics"))
            file_writer.set_as_default()
            train_ds = dataset.get(
                'train')  # get first to calculate train size
            model.fit(
                train_ds,
                epochs=FLAGS.epoch,
                validation_data=dataset.get('valid'),
                callbacks=callbacks,
                initial_epoch=largest_epoch,
                verbose=verbose,
            )

            # evaluate before train on valid
            # result = model.evaluate(
            #     dataset.get('test'),
            # )
            # logging.info('evaluate before train on valid result:')
            # for i in range(len(result)):
            #     logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))
        if 'test' in FLAGS.mode:
            # train on the valid split
            # model.fit(
            #     dataset.get('valid'),
            #     epochs=3,
            #     # callbacks=callbacks,
            #     verbose=verbose
            # )
            # model.save_weights(os.path.join(model_dir, 'model.h5'))
            # evaluate on the test split
            result = model.evaluate(dataset.get('test'))
            logging.info('evaluate result:')
            for i in range(len(result)):
                logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))
            # TODO: remove previous checkpoint
        if 'predict' in FLAGS.mode:
            files = read_txt(
                os.path.join(BASE_DIR,
                             'dataset/%s/predict.txt' % FLAGS.dataset))
            output_dir = FLAGS.predict_output_dir
            os.makedirs(output_dir, exist_ok=True)
            i = 0
            ds = dataset.get('predict')
            for batch in ds:
                predict = model.predict(batch)
                for p in predict:
                    if i % 1000 == 0:
                        logging.info('curr: %s/%s' % (i, len(files)))
                    p_r = tf.squeeze(tf.argmax(
                        p, axis=-1)).numpy().astype('int16')
                    p_r = (p_r + 1) * 100
                    p_im = Image.fromarray(p_r)
                    im_path = os.path.join(
                        output_dir, '%s.png' % files[i].split('/')[-1][:-4])
                    p_im.save(im_path)
                    i += 1
        if FLAGS.task == 'visualize_result':
            dataset.visualize_evaluate(model, FLAGS.mode)
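
A side note on the checkpoint-resume logic in Example No. 5: the epoch is recovered by slicing fixed character offsets (int(i[-12:-8])), which silently breaks if the model-{epoch:04d}.ckpt.h5 naming ever changes. A small hypothetical helper using a regular expression is more robust:

import re

def epoch_from_checkpoint(path):
    # Pull the epoch number out of "model-0042.ckpt.h5"-style filenames.
    match = re.search(r"model-(\d+)\.ckpt\.h5$", path)
    return int(match.group(1)) if match else None

assert epoch_from_checkpoint("exp/ckpt/model-0042.ckpt.h5") == 42
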