def main():
    dataset = CIFAR10(
        binary=True,
        validation_split=0.0)  # not using validation for anything
    model = mobilenet_v2_like(dataset.input_shape, dataset.num_classes)
    model.compile(loss=SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=SGDW(lr=0.01, momentum=0.9, weight_decay=1e-5),
                  metrics=['accuracy'])
    model.summary()

    batch_size = 128
    train_data = dataset.train_dataset() \
        .shuffle(8 * batch_size) \
        .batch(batch_size) \
        .prefetch(tf.data.experimental.AUTOTUNE)
    valid_data = dataset.test_dataset() \
        .batch(batch_size).prefetch(tf.data.experimental.AUTOTUNE)

    def lr_schedule(epoch):
        if 0 <= epoch < 35:
            return 0.01
        if 35 <= epoch < 65:
            return 0.005
        return 0.001

    model.fit(train_data, validation_data=valid_data, epochs=80,
              callbacks=[LearningRateScheduler(lr_schedule)])
    model.save("cnn-cifar10-binary.h5")
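# A related pattern (used in the VGG-16 snippet further below) is to drive both the
# learning rate and the decoupled weight decay from step schedules instead of a
# LearningRateScheduler callback, so the two decay together. A minimal sketch, assuming
# SGDW is tensorflow_addons.optimizers.SGDW and reusing the epoch breakpoints above;
# steps_per_epoch is an illustrative value, not taken from the script:
import tensorflow as tf
import tensorflow_addons as tfa

steps_per_epoch = 50000 // 128
boundaries = [35 * steps_per_epoch, 65 * steps_per_epoch]
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, [0.01, 0.005, 0.001])
wd_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries, [1e-5, 5e-6, 1e-6])  # weight decay scaled in step with the lr
optimizer = tfa.optimizers.SGDW(
    learning_rate=lr_schedule, momentum=0.9, weight_decay=wd_schedule)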
def test_save_model_and_load_model_tf_optimizer(self):
    m1 = fe.build(
        fe.architecture.tensorflow.LeNet,
        optimizer_fn=lambda: SGDW(weight_decay=2e-5, learning_rate=2e-4))
    temp_folder = tempfile.mkdtemp()
    fe.backend.save_model(m1,
                          save_dir=temp_folder,
                          model_name="test",
                          save_optimizer=True)

    m2 = fe.build(
        fe.architecture.tensorflow.LeNet,
        optimizer_fn=lambda: SGDW(weight_decay=1e-5, learning_rate=1e-4))
    fe.backend.load_model(m2,
                          weights_path=os.path.join(temp_folder, "test.h5"),
                          load_optimizer=True)

    self.assertTrue(np.allclose(fe.backend.get_lr(model=m2), 2e-4))
    self.assertTrue(
        np.allclose(
            tf.keras.backend.get_value(m2.current_optimizer.weight_decay),
            2e-5))
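# A plain-Keras sketch of the same optimizer round trip, without FastEstimator.
# Assumes SGDW comes from tensorflow_addons (which registers it for Keras
# deserialization once imported); the model and hyperparameter values are illustrative:
import os
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer=tfa.optimizers.SGDW(weight_decay=2e-5, learning_rate=2e-4),
              loss='mse')
model.fit(np.zeros((8, 4)), np.zeros((8, 1)), epochs=1, verbose=0)

path = os.path.join(tempfile.mkdtemp(), 'model.h5')
model.save(path, include_optimizer=True)

restored = tf.keras.models.load_model(path)
print(tf.keras.backend.get_value(restored.optimizer.learning_rate))  # ~2e-4
print(tf.keras.backend.get_value(restored.optimizer.weight_decay))   # ~2e-5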
search_algorithm = AgingEvoSearch


def lr_schedule(epoch):
    if 0 <= epoch < 25:
        return 0.01
    if 25 <= epoch < 35:
        return 0.005
    return 0.001


training_config = TrainingConfig(
    dataset=FashionMNIST(),
    batch_size=128,
    epochs=45,
    optimizer=lambda: SGDW(lr=0.01, momentum=0.9, weight_decay=1e-5),
    callbacks=lambda: [LearningRateScheduler(lr_schedule)]
)

search_config = AgingEvoConfig(
    search_space=CnnSearchSpace(dropout=0.15),
    checkpoint_dir="artifacts/cnn_fashion"
)

bound_config = BoundConfig(
    error_bound=0.10,
    peak_mem_bound=64000,
    model_size_bound=64000,
    mac_bound=30000000
)
    steps = [(90 - initial_epoch) * steps_per_epoch]
    decay = [0.01, 0.001]
    lr_schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        steps, [0.05 * d for d in decay])
    wd_schedule = tf.optimizers.schedules.PiecewiseConstantDecay(
        steps, [0.0001 * d for d in decay])
else:
    lr_schedule = 0.05 * 0.001
    wd_schedule = 0.0001 * 0.001

# Create and compile TF model
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    tf_model = vgg16()
    optimizer = SGDW(learning_rate=lr_schedule,
                     momentum=momentum,
                     nesterov=True,
                     weight_decay=wd_schedule)
    tf_model.compile(optimizer=optimizer,
                     loss='sparse_categorical_crossentropy',
                     metrics=['accuracy'])

    if args.reuse_tf_model:
        # Load old weights
        tf_model.load_weights('vgg16_imagenet_tf_weights.h5')
    else:
        # Load newest checkpoint weights if present
        if newest_checkpoint_file is not None:
            print(
                f'Loading epoch {initial_epoch} from checkpoint {newest_checkpoint_file}'
            )
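# Note on the schedules above: PiecewiseConstantDecay takes N boundaries and N + 1
# values, which is why the single boundary ([(90 - initial_epoch) * steps_per_epoch])
# is paired with two decay factors. A toy check with made-up step numbers (not taken
# from the training script):
import tensorflow as tf

sched = tf.optimizers.schedules.PiecewiseConstantDecay(
    [100], [0.05 * 0.01, 0.05 * 0.001])
print(float(sched(0)))    # 0.0005 up to and including step 100
print(float(sched(101)))  # 5e-05 afterwards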
def main(argv):
    del argv

    # paths
    data_dir = os.path.join(BASE_DIR, 'dataset', FLAGS.dataset)
    exp_dir = os.path.join(data_dir, 'exp', FLAGS.exp_name)
    model_dir = os.path.join(exp_dir, 'ckpt')
    log_dir = exp_dir
    os.makedirs(model_dir, exist_ok=True)
    # os.makedirs(log_dir, exist_ok=True)
    model_path = os.path.join(model_dir, 'model-{epoch:04d}.ckpt.h5')

    # logging
    log_format = '%(asctime)s %(message)s'
    logging.basicConfig(stream=sys.stdout,
                        level=logging.DEBUG,
                        format=log_format,
                        datefmt='%m/%d %I:%M:%S %p')
    fh = logging.FileHandler(os.path.join(log_dir, 'log.txt'))
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info(
        '------------------------------experiment start------------------------------------'
    )
    for i in (
            'exp_name',
            'dataset',
            'model',
            'mode',
            'lr',
    ):
        logging.info(
            '%s: %s' %
            (i, FLAGS.get_flag_value(i, '########VALUE MISSED#########')))
    logging.info(FLAGS.flag_values_dict())

    # resume from checkpoint
    largest_epoch = 0
    if FLAGS.resume == 'ckpt':
        chkpts = tf.io.gfile.glob(model_dir + '/*.ckpt.h5')
        if len(chkpts):
            largest_epoch = sorted([int(i[-12:-8]) for i in chkpts],
                                   reverse=True)[0]
            print('resume from epoch', largest_epoch)
            weight_path = model_path.format(epoch=largest_epoch)
        else:
            weight_path = None
    elif len(FLAGS.resume):
        assert os.path.isfile(FLAGS.resume)
        weight_path = FLAGS.resume
    else:
        weight_path = None

    dataset = importlib.import_module(
        'dataset.%s.data_loader' %
        FLAGS.dataset).DataLoader(**FLAGS.flag_values_dict())

    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = globals()[FLAGS.model](**FLAGS.flag_values_dict())
        # model = alexnet()
        if FLAGS.resume and weight_path:
            logging.info('resume from previous ckp: %s' % largest_epoch)
            model.load_weights(weight_path)
        # model.layers[1].trainable = False
        loss = globals()[FLAGS.loss]
        model.compile(
            optimizer=SGDW(momentum=0.9,
                           learning_rate=FLAGS.lr,
                           weight_decay=FLAGS.weight_decay),
            loss=loss,
            metrics=[
                "accuracy",
                Recall(),
                Precision(),
                MeanIoU(num_classes=FLAGS.classes)
            ],
        )

    # if 'train' in FLAGS.mode:
    #     model.summary()
    logging.info('There are %s layers in model' % len(model.layers))
    if FLAGS.freeze_layers > 0:
        logging.info('Freeze first %s layers' % FLAGS.freeze_layers)
        for i in model.layers[:FLAGS.freeze_layers]:
            i.trainable = False

    verbose = 1 if FLAGS.debug is True else 2

    if 'train' in FLAGS.mode:
        callbacks = [
            model_checkpoint(filepath=model_path,
                             monitor=FLAGS.model_checkpoint_monitor),
            tensorboard(log_dir=os.path.join(exp_dir, 'tb-logs')),
            early_stopping(monitor=FLAGS.model_checkpoint_monitor,
                           patience=FLAGS.early_stopping_patience),
            lr_schedule(name=FLAGS.lr_schedule, epochs=FLAGS.epoch)
        ]
        file_writer = tf.summary.create_file_writer(
            os.path.join(exp_dir, 'tb-logs', "metrics"))
        file_writer.set_as_default()
        train_ds = dataset.get('train')  # get first to calculate train size
        model.fit(
            train_ds,
            epochs=FLAGS.epoch,
            validation_data=dataset.get('valid'),
            callbacks=callbacks,
            initial_epoch=largest_epoch,
            verbose=verbose,
        )
        # evaluate before train on valid
        # result = model.evaluate(
        #     dataset.get('test'),
        # )
        # logging.info('evaluate before train on valid result:')
        # for i in range(len(result)):
        #     logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))

    if 'test' in FLAGS.mode:
        # fit on valid (disabled)
        # model.fit(
        #     dataset.get('valid'),
        #     epochs=3,
        #     # callbacks=callbacks,
        #     verbose=verbose
        # )
        # model.save_weights(os.path.join(model_dir, 'model.h5'))

        # evaluate on test
        result = model.evaluate(dataset.get('test'), )
        logging.info('evaluate result:')
        for i in range(len(result)):
            logging.info('%s:\t\t%s' % (model.metrics_names[i], result[i]))
        # TODO: remove previous checkpoint

    if 'predict' in FLAGS.mode:
        files = read_txt(
            os.path.join(BASE_DIR, 'dataset/%s/predict.txt' % FLAGS.dataset))
        output_dir = FLAGS.predict_output_dir
        os.makedirs(output_dir, exist_ok=True)
        i = 0
        ds = dataset.get('predict')
        for batch in ds:
            predict = model.predict(batch)
            for p in predict:
                if i % 1000 == 0:
                    logging.info('curr: %s/%s' % (i, len(files)))
                p_r = tf.squeeze(tf.argmax(p, axis=-1)).numpy().astype('int16')
                p_r = (p_r + 1) * 100
                p_im = Image.fromarray(p_r)
                im_path = os.path.join(
                    output_dir, '%s.png' % files[i].split('/')[-1][:-4])
                p_im.save(im_path)
                i += 1

    if FLAGS.task == 'visualize_result':
        dataset.visualize_evaluate(model, FLAGS.mode)