import json
import os.path as osp

import numpy as np
import tensorflow as tf
from pycocotools.cocoeval import COCOeval

from dataset.dataloader import load_tfds
# get_preds, suppress_stdout, and merge_coco_annotations are repo-local
# helpers; their module paths are not shown in this file.


def validate(strategy, cfg, model=None, split='val', clear_foot=False):
    cfg.DATASET.CACHE = False
    result_path = '{}/{}_{}.json'.format(cfg.MODEL.SAVE_DIR, cfg.MODEL.NAME, split)

    if split == 'val':
        with suppress_stdout():
            coco = merge_coco_annotations(cfg.DATASET.ANNOT, split, clear_foot)

    if model is None:
        with strategy.scope():
            model = tf.keras.models.load_model(
                osp.join(cfg.MODEL.SAVE_DIR, cfg.MODEL.NAME + '.h5'),
                compile=False)

    cfg.DATASET.OUTPUT_SHAPE = model.output_shape[1:]

    ds = load_tfds(cfg, split, det=cfg.VAL.DET, predict_kp=True,
                   drop_remainder=cfg.VAL.DROP_REMAINDER)
    ds = strategy.experimental_distribute_dataset(ds)

    @tf.function
    def predict(imgs, flip=False):
        if flip:
            imgs = imgs[:, :, ::-1, :]
        return model(imgs, training=False)

    results = []
    for count, batch in enumerate(ds):
        ids, imgs, _, Ms, scores = batch

        ids = np.concatenate(ids.values, axis=0)
        scores = np.concatenate(scores.values, axis=0)
        Ms = np.concatenate(Ms.values, axis=0)

        hms = strategy.run(predict, args=(imgs,)).values
        hms = np.array(np.concatenate(hms, axis=0), np.float32)

        if cfg.VAL.FLIP:
            flip_hms = strategy.run(predict, args=(imgs, True,)).values
            flip_hms = np.concatenate(flip_hms, axis=0)
            flip_hms = flip_hms[:, :, ::-1, :]
            # swap the left/right keypoint channels after the horizontal flip
            tmp = flip_hms.copy()
            for i in range(len(cfg.DATASET.KP_FLIP)):
                flip_hms[:, :, :, i] = tmp[:, :, :, cfg.DATASET.KP_FLIP[i]]
            # shift to align features
            flip_hms[:, :, 1:, :] = flip_hms[:, :, 0:-1, :].copy()
            hms = (hms + flip_hms) / 2.

        preds = get_preds(hms, Ms, cfg.DATASET.INPUT_SHAPE,
                          cfg.DATASET.OUTPUT_SHAPE)

        # pad predictions to 23 keypoints (COCO body + foot)
        all_preds = np.zeros((preds.shape[0], 23, 3))
        all_preds[:, :preds.shape[1], :] = preds
        preds = all_preds

        kp_scores = preds[:, :, -1].copy()

        # rescore: mean confidence of keypoints above the threshold,
        # weighted by the detection score
        rescored_score = np.zeros(len(kp_scores))
        for i in range(len(kp_scores)):
            score_mask = kp_scores[i] > cfg.VAL.SCORE_THRESH
            if np.sum(score_mask) > 0:
                rescored_score[i] = np.mean(kp_scores[i][score_mask]) * scores[i]

        for i in range(preds.shape[0]):
            results.append(dict(image_id=int(ids[i]),
                                category_id=1,
                                keypoints=preds[i].reshape(-1).tolist(),
                                score=float(rescored_score[i])))

        if cfg.TRAIN.DISP:
            print('completed preds batch', count + 1)

    with open(result_path, 'w') as f:
        json.dump(results, f)

    if split == 'val':
        with suppress_stdout():
            result = coco.loadRes(result_path)
            cocoEval = COCOeval(coco, result, iouType='keypoints')
            cocoEval.evaluate()
            cocoEval.accumulate()
            cocoEval.summarize()
        # COCO keypoint eval stats are ordered:
        # AP, AP_50, AP_75, AP_medium, AP_large, AR, ...
        # (keypoint evaluation has no 'small' area range)
        mAP, AP_50, AP_75, AP_medium, AP_large, AR = cocoEval.stats[:6]
        return mAP, AP_50, AP_75, AP_medium, AP_large, AR
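# In the flip test above, the per-channel loop that swaps left/right heatmaps
# is equivalent to a single fancy-indexing operation. A minimal sketch; the
# 17-keypoint COCO KP_FLIP permutation and the toy heatmap shape below are
# illustrative assumptions, not values read from the config:
def _kp_flip_swap_demo():
    import numpy as np
    kp_flip = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
    hms = np.random.rand(2, 64, 48, 17).astype(np.float32)  # N, H, W, K
    swapped = hms.copy()
    tmp = swapped.copy()
    for i in range(len(kp_flip)):  # loop version, as in validate()
        swapped[:, :, :, i] = tmp[:, :, :, kp_flip[i]]
    assert np.array_equal(swapped, hms[:, :, :, kp_flip])  # vectorized form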
import os

import numpy as np
import tensorflow as tf
import wandb
from tensorflow.keras.mixed_precision import experimental as mixed_precision
from wandb.keras import WandbCallback

from dataset.dataloader import load_tfds
# Repo-local imports assumed: SimpleBaseline, HRNet, EvoPose, EfficientNetLite,
# EfficientNet (model definitions), mse (loss), setup_wandb, and the
# WarmupCosineDecay / WarmupPiecewise LR schedules.


def train(strategy, cfg):
    os.makedirs(cfg.MODEL.SAVE_DIR, exist_ok=True)

    if cfg.DATASET.BFLOAT16:
        policy = mixed_precision.Policy('mixed_bfloat16')
        mixed_precision.set_policy(policy)

    tf.random.set_seed(cfg.TRAIN.SEED)
    np.random.seed(cfg.TRAIN.SEED)

    # steps per training epoch and per validation pass
    spe = int(np.ceil(cfg.DATASET.TRAIN_SAMPLES / cfg.TRAIN.BATCH_SIZE))
    spv = cfg.DATASET.VAL_SAMPLES // cfg.VAL.BATCH_SIZE

    if cfg.TRAIN.SCALE_LR:
        # linear scaling rule relative to a base batch size of 32
        lr = cfg.TRAIN.BASE_LR * cfg.TRAIN.BATCH_SIZE / 32
        cfg.TRAIN.WARMUP_FACTOR = 32 / cfg.TRAIN.BATCH_SIZE
    else:
        lr = cfg.TRAIN.BASE_LR

    if cfg.TRAIN.LR_SCHEDULE == 'warmup_cosine_decay':
        lr_schedule = WarmupCosineDecay(
            initial_learning_rate=lr,
            decay_steps=cfg.TRAIN.EPOCHS * spe,
            warmup_steps=cfg.TRAIN.WARMUP_EPOCHS * spe,
            warmup_factor=cfg.TRAIN.WARMUP_FACTOR)
    elif cfg.TRAIN.LR_SCHEDULE == 'warmup_piecewise':
        lr_schedule = WarmupPiecewise(
            boundaries=[x * spe for x in cfg.TRAIN.DECAY_EPOCHS],
            values=[lr, lr / 10, lr / 10 ** 2],
            warmup_steps=spe * cfg.TRAIN.WARMUP_EPOCHS,
            warmup_factor=cfg.TRAIN.WARMUP_FACTOR)
    else:
        lr_schedule = lr

    with strategy.scope():
        optimizer = tf.keras.optimizers.Adam(lr_schedule)
        if cfg.TRAIN.WANDB_RUN_ID:
            # resume from the best checkpoint of an existing wandb run
            api = wandb.Api()
            run = api.run(f"{cfg.EVAL.WANDB_RUNS}/{cfg.TRAIN.WANDB_RUN_ID}")
            run.file("model-best.h5").download(replace=True)
            model = tf.keras.models.load_model(
                'model-best.h5',
                custom_objects={'relu6': tf.nn.relu6,
                                'WarmupCosineDecay': WarmupCosineDecay})
            model.compile(optimizer=model.optimizer, loss=mse)
        else:
            if cfg.MODEL.TYPE == 'simple_baseline':
                model = SimpleBaseline(cfg)
            elif cfg.MODEL.TYPE == 'hrnet':
                model = HRNet(cfg)
            elif cfg.MODEL.TYPE == 'evopose':
                model = EvoPose(cfg)
            elif cfg.MODEL.TYPE == 'eflite':
                model = EfficientNetLite(cfg)
            elif cfg.MODEL.TYPE == 'ef':
                model = EfficientNet(cfg)
            model.compile(optimizer=optimizer, loss=mse)

    cfg.DATASET.OUTPUT_SHAPE = model.output_shape[1:]
    # scale the target Gaussian sigma with the output resolution
    cfg.DATASET.SIGMA = 2 * cfg.DATASET.OUTPUT_SHAPE[0] / 64

    wandb_config = setup_wandb(cfg, model)

    train_ds = load_tfds(cfg, 'train')
    train_ds = strategy.experimental_distribute_dataset(train_ds)

    if cfg.TRAIN.VAL:
        val_ds = load_tfds(cfg, 'val')
        val_ds = strategy.experimental_distribute_dataset(val_ds)
    else:
        # avoid a NameError when validation is disabled
        val_ds = None

    print('Training {} ({} / {}) on {} for {} epochs'
          .format(cfg.MODEL.NAME, wandb_config.parameters, wandb_config.flops,
                  cfg.TRAIN.ACCELERATOR, cfg.TRAIN.EPOCHS))

    initial_epoch = 0
    if cfg.TRAIN.WANDB_RUN_ID:
        initial_epoch = cfg.TRAIN.INITIAL_EPOCH

    model.fit(train_ds,
              initial_epoch=initial_epoch,
              epochs=cfg.TRAIN.EPOCHS,
              verbose=1,
              validation_data=val_ds,
              validation_steps=spv if val_ds is not None else None,
              steps_per_epoch=spe,
              callbacks=[WandbCallback()])

    return model
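# WarmupCosineDecay is defined elsewhere in the repo. A minimal sketch of the
# schedule shape implied by how it is constructed above (linear warmup from
# lr * warmup_factor to lr, then cosine decay toward zero). This illustrates
# the assumed semantics only; it is not the repo's implementation:
class WarmupCosineDecaySketch(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, initial_learning_rate, decay_steps, warmup_steps,
                 warmup_factor=0.1):
        super().__init__()
        self.initial_learning_rate = initial_learning_rate
        self.decay_steps = decay_steps
        self.warmup_steps = warmup_steps
        self.warmup_factor = warmup_factor

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        decay_steps = tf.cast(self.decay_steps, tf.float32)
        lr = self.initial_learning_rate
        # linear warmup: lr * warmup_factor -> lr over warmup_steps
        warmup_lr = lr * (self.warmup_factor
                          + (1. - self.warmup_factor) * step / warmup_steps)
        # cosine decay from lr to 0 over the remaining steps
        progress = (step - warmup_steps) / tf.maximum(1., decay_steps - warmup_steps)
        cosine_lr = lr * 0.5 * (1. + tf.cos(np.pi * progress))
        return tf.where(step < warmup_steps, warmup_lr, cosine_lr)

    def get_config(self):
        return dict(initial_learning_rate=self.initial_learning_rate,
                    decay_steps=self.decay_steps,
                    warmup_steps=self.warmup_steps,
                    warmup_factor=self.warmup_factor)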
import cv2
import numpy as np
import tensorflow as tf

import dataset.plots as pl
import dataset.dataloader as dl
from dataset.coco import cn as cfg

tf.random.set_seed(0)

cfg.DATASET.INPUT_SHAPE = [512, 384, 3]
cfg.DATASET.NORM = False
cfg.DATASET.BGR = True
cfg.DATASET.HALF_BODY_PROB = 1.

ds = dl.load_tfds(cfg, 'val', det=False, predict_kp=True,
                  drop_remainder=False, visualize=True)

for ids, imgs, kps, Ms, scores, hms, valids in ds:
    for i in range(cfg.TRAIN.BATCH_SIZE):
        kp = kps[i]
        # only show examples with annotated foot keypoints (indices 17+)
        if np.sum(kp[:, 2][17:]) > 0:
            img = imgs[i]
            pl.plot_image(np.uint8(img), hms[i], kp[:, -1].numpy())
            cv2.imshow('', dl.visualize(np.uint8(img), kp[:, :2].numpy(),
                                        kp[:, -1].numpy()))
            cv2.waitKey()

cv2.destroyAllWindows()
import os
import os.path as osp
import pickle
from time import time

import numpy as np
import tensorflow as tf
from tensorflow.keras.mixed_precision import experimental as mixed_precision

from dataset.dataloader import load_tfds
# Repo-local imports assumed: SimpleBaseline, HRNet, EvoPose (models),
# mse_loss, get_flops, WarmupCosineDecay, WarmupPiecewise.


def train(strategy, cfg):
    os.makedirs(cfg.MODEL.SAVE_DIR, exist_ok=True)

    if cfg.DATASET.BFLOAT16:
        policy = mixed_precision.Policy('mixed_bfloat16')
        mixed_precision.set_policy(policy)

    tf.random.set_seed(cfg.TRAIN.SEED)
    np.random.seed(cfg.TRAIN.SEED)

    meta_data = {'train_loss': [], 'val_loss': [], 'config': cfg}

    # steps per training epoch and per validation pass
    spe = int(np.ceil(cfg.DATASET.TRAIN_SAMPLES / cfg.TRAIN.BATCH_SIZE))
    spv = cfg.DATASET.VAL_SAMPLES // cfg.VAL.BATCH_SIZE

    if cfg.TRAIN.SCALE_LR:
        # linear scaling rule relative to a base batch size of 32
        lr = cfg.TRAIN.BASE_LR * cfg.TRAIN.BATCH_SIZE / 32
        cfg.TRAIN.WARMUP_FACTOR = 32 / cfg.TRAIN.BATCH_SIZE
    else:
        lr = cfg.TRAIN.BASE_LR

    if cfg.TRAIN.LR_SCHEDULE == 'warmup_cosine_decay':
        lr_schedule = WarmupCosineDecay(
            initial_learning_rate=lr,
            decay_steps=cfg.TRAIN.EPOCHS * spe,
            warmup_steps=cfg.TRAIN.WARMUP_EPOCHS * spe,
            warmup_factor=cfg.TRAIN.WARMUP_FACTOR)
    elif cfg.TRAIN.LR_SCHEDULE == 'warmup_piecewise':
        lr_schedule = WarmupPiecewise(
            boundaries=[x * spe for x in cfg.TRAIN.DECAY_EPOCHS],
            values=[lr, lr / 10, lr / 10 ** 2],
            warmup_steps=spe * cfg.TRAIN.WARMUP_EPOCHS,
            warmup_factor=cfg.TRAIN.WARMUP_FACTOR)
    else:
        lr_schedule = lr

    with strategy.scope():
        optimizer = tf.keras.optimizers.Adam(lr_schedule)
        if cfg.MODEL.TYPE == 'simple_baseline':
            model = SimpleBaseline(cfg)
        elif cfg.MODEL.TYPE == 'hrnet':
            model = HRNet(cfg)
        elif cfg.MODEL.TYPE == 'evopose':
            model = EvoPose(cfg)
        train_loss = tf.keras.metrics.Mean()
        val_loss = tf.keras.metrics.Mean()

    cfg.DATASET.OUTPUT_SHAPE = model.output_shape[1:]
    # scale the target Gaussian sigma with the output resolution
    cfg.DATASET.SIGMA = 2 * cfg.DATASET.OUTPUT_SHAPE[0] / 64

    meta_data['parameters'] = model.count_params()
    meta_data['flops'] = get_flops(model)

    train_ds = load_tfds(cfg, 'train')
    train_ds = strategy.experimental_distribute_dataset(train_ds)
    train_iterator = iter(train_ds)

    if cfg.TRAIN.VAL:
        val_ds = load_tfds(cfg, 'val')
        val_ds = strategy.experimental_distribute_dataset(val_ds)

    @tf.function
    def train_step(train_iterator):
        def step_fn(inputs):
            imgs, targets, valid = inputs
            with tf.GradientTape() as tape:
                loss, l2_loss = mse_loss(model, imgs, targets, valid,
                                         training=True)
                # divide by the replica count so the summed gradients
                # match the global batch loss
                scaled_loss = (loss + l2_loss) / strategy.num_replicas_in_sync
            grads = tape.gradient(scaled_loss, model.trainable_variables)
            optimizer.apply_gradients(
                list(zip(grads, model.trainable_variables)))
            train_loss.update_state(loss)
        strategy.run(step_fn, args=(next(train_iterator),))

    @tf.function
    def val_step(dist_inputs):
        def step_fn(inputs):
            imgs, targets, valid = inputs
            loss, _ = mse_loss(model, imgs, targets, valid, training=False)
            val_loss.update_state(loss)
        strategy.run(step_fn, args=(dist_inputs,))

    print('Training {} ({:.2f}M / {:.2f}G) on {} for {} epochs'.format(
        cfg.MODEL.NAME, meta_data['parameters'] / 1e6,
        meta_data['flops'] / 2 / 1e9, cfg.TRAIN.ACCELERATOR, cfg.TRAIN.EPOCHS))

    epoch = 1
    ts = time()
    while epoch <= cfg.TRAIN.EPOCHS:
        te = time()

        for i in range(spe):
            train_step(train_iterator)
            if cfg.TRAIN.DISP:
                print('epoch {} ({}/{}) | loss: {:.1f}'.format(
                    epoch, i + 1, spe, train_loss.result().numpy()))
        meta_data['train_loss'].append(train_loss.result().numpy())

        if cfg.TRAIN.VAL:
            for i, batch in enumerate(val_ds):
                val_step(batch)
                if cfg.TRAIN.DISP:
                    print('val {} ({}/{}) | loss: {:.1f}'.format(
                        epoch, i + 1, spv, val_loss.result().numpy()))
            meta_data['val_loss'].append(val_loss.result().numpy())

            if cfg.VAL.SAVE_BEST:
                # always cache on the first epoch, then on improvement
                if epoch == 1 or val_loss.result().numpy() < best_loss:
                    best_weights = model.get_weights()
                    best_loss = val_loss.result().numpy()
                    if cfg.TRAIN.DISP:
                        print('Cached model weights')

        train_loss.reset_states()
        val_loss.reset_states()

        if cfg.TRAIN.SAVE_EPOCHS and epoch % cfg.TRAIN.SAVE_EPOCHS == 0:
            ckpt_path = osp.join(
                cfg.MODEL.SAVE_DIR,
                '{}_ckpt{:03d}.h5'.format(cfg.MODEL.NAME, epoch))
            model.save(ckpt_path, save_format='h5')
            print('Saved checkpoint to', ckpt_path)

        if cfg.TRAIN.SAVE_META:
            pickle.dump(
                meta_data,
                open(osp.join(cfg.MODEL.SAVE_DIR,
                              '{}_meta.pkl'.format(cfg.MODEL.NAME)), 'wb'))

        if epoch > 1 and cfg.TRAIN.DISP:
            est_time = (cfg.TRAIN.EPOCHS - epoch) * (time() - te) / 3600
            print('Estimated time remaining: {:.2f} hrs'.format(est_time))

        epoch += 1

    meta_data['training_time'] = time() - ts

    # restore the best validation weights (only cached when validating)
    if cfg.TRAIN.VAL and cfg.VAL.SAVE_BEST:
        model.set_weights(best_weights)

    return model, meta_data
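# mse_loss is a repo-local helper; train_step above relies on it returning a
# (loss, l2_loss) pair per replica. A plausible sketch of a heatmap MSE masked
# by keypoint validity, with weight decay collected from model.losses, is
# below. The masking and reduction details are assumptions, not the repo's
# implementation:
def mse_loss_sketch(model, imgs, targets, valid, training=True):
    hms = tf.cast(model(imgs, training=training), tf.float32)
    # valid: [batch, K] keypoint visibility -> broadcast over H and W
    mask = tf.cast(valid, tf.float32)[:, tf.newaxis, tf.newaxis, :]
    loss = tf.reduce_mean(tf.square(hms - targets) * mask)
    # L2 terms registered by layer kernel_regularizers, if any
    l2_loss = tf.add_n(model.losses) if model.losses else tf.constant(0.)
    return loss, l2_loss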