Exemplo n.º 1
0
    "train HW": TRAIN_HW,
    "minival_gt_stddevs": MINIVAL_GT_STDDEVS,
    "val_gt_stddevs": VAL_GT_STDDEVS,
    "train_gt_stddevs": TRAIN_GT_STDDEVS,
    "distillation_alpha": DISTILLATION_ALPHA,
    **SCHEDULER_HYPERPARS
}
# Cast every hyperparameter name and value to str, then report the resulting
# table through both tb_logger and txt_logger.
HPARS_DICT = {str(name): str(value) for name, value in HPARS_DICT.items()}
tb_logger.add_hparams(HPARS_DICT, {})  # the second (empty) dict is passed as-is
txt_logger.info("HYPERPARAMETERS:\n{}".format(HPARS_DICT))

# INSTANTIATE OPTIMIZER
loss_fn = DistillationLoss()
# Every student parameter is handed to the optimizer. Original author's note:
# a non-trainable stem already runs under torch.no_grad, so the optimizer
# will not train it.
opt = get_sgd_optimizer(
    student.parameters(),
    half_precision=HALF_PRECISION,
    **OPT_INIT_PARAMS
)
# The scheduler is attached to the object exposed as opt.optimizer.
lr_scheduler = SgdrScheduler(opt.optimizer, **SCHEDULER_HYPERPARS)

# INSTANTIATE DATALOADERS
# MINIVAL_FILE lists one entry per line, each an integer ID optionally followed
# by a ".jpg" extension. The extension is removed as an exact suffix
# (str.rstrip would treat ".jpg\n" as a set of characters), surrounding
# whitespace including "\r" from CRLF files is stripped, and blank lines are
# skipped instead of crashing int().
with open(MINIVAL_FILE, "r") as f:
    MINIVAL_IDS = [
        int(entry[:-len(".jpg")] if entry.endswith(".jpg") else entry)
        for entry in (line.strip() for line in f)
        if entry
    ]

# Image normalization pipeline: currently a single Normalize step that uses
# IMG_NORM_MEAN / IMG_NORM_STDDEV with inplace=True.
# Open question left by the original author: add jitter / to-gray here?
IMG_NORMALIZE_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.Normalize(
            IMG_NORM_MEAN, IMG_NORM_STDDEV, inplace=True
        ),
    ]
)

AUGMENTATION_TRANSFORM = SeededCompose([
# Report the hyperparameter table through tb_logger (with an empty second
# dict) and write the same table to txt_logger.
tb_logger.add_hparams(HPARS_DICT, {})
txt_logger.info("HYPERPARAMETERS:\n{}".format(HPARS_DICT))


# INSTANTIATE OPTIMIZER
# Weight used for the detection loss. Original note: 100 means that black
# happens 100 more times than white.
DET_POS_WEIGHT = 100
det_loss_fn = DistillationBceLossKeypointMining(
    DET_POS_WEIGHT, DET_POS_WEIGHT, DEVICE
)
# NOTE (from the original author): torch.nn.BCELoss should be the loss to use
# here, but it has no pos_weight; BCEWithLogitsLoss works and the GPU is
# bloated, so the logits variant is kept for now although a sigmoid output is
# provided to it.
att_loss_fn = torch.nn.BCEWithLogitsLoss(
    pos_weight=torch.ones(1) * 7
).to(DEVICE)

# Original author's note: a non-trainable stem already has torch.no_grad, so
# the optimizers will not train it.
# First optimizer: parameters of the att_* submodules only (student.mid_stem
# is intentionally not part of this group; it belongs to the second one).
params = [
    weight
    for module in (student.att_lo, student.att_mid,
                   student.att_hi, student.att_top)
    for weight in module.parameters()
]
att_opt = get_sgd_optimizer(
    params, half_precision=HALF_PRECISION, **OPT_INIT_PARAMS
)
att_lr_scheduler = SgdrScheduler(att_opt.optimizer, **SCHEDULER_HYPERPARS)

# Second optimizer: mid_stem, steps and alt_img_stem parameters.
params = [
    weight
    for module in (student.mid_stem, student.steps, student.alt_img_stem)
    for weight in module.parameters()
]
det_opt = get_sgd_optimizer(
    params, half_precision=HALF_PRECISION, **OPT_INIT_PARAMS
)
det_lr_scheduler = SgdrScheduler(det_opt.optimizer, **SCHEDULER_HYPERPARS)


# INSTANTIATE DATALOADERS
# Both ID files list one entry per line, each an integer ID optionally
# followed by a ".jpg" extension. The extension is removed as an exact suffix
# (str.rstrip would treat ".jpg\n" as a set of characters), surrounding
# whitespace including "\r" from CRLF files is stripped, and blank lines are
# skipped instead of crashing int().
with open(MINIVAL_FILE, "r") as f:
    MINIVAL_IDS = [
        int(entry[:-len(".jpg")] if entry.endswith(".jpg") else entry)
        for entry in (line.strip() for line in f)
        if entry
    ]

with open(EASY_VAL_SMALL_PATH, "r") as f:
    EASY_IDS = [
        int(entry[:-len(".jpg")] if entry.endswith(".jpg") else entry)
        for entry in (line.strip() for line in f)
        if entry
    ]
# Paths
VAL_DIR = os.path.join(HOME, "datasets/coco/val2017")
MODEL_PATH = "models/pose_higher_hrnet_w48_640.pth.tar"

# Settings the original author marks as hardcoded to the architecture
INPUT_SIZE = 640
HALF_PRECISION = True

# Run configuration
VERBOSE = True
DEVICE = "cuda"
MINIVAL_SIZE = 500
BATCH_SIZE = 4
NUM_EPOCHS = 5

# One random 1 x 3 x INPUT_SIZE x INPUT_SIZE tensor, moved to DEVICE
DUMMY_INPUT = torch.rand(1, 3, INPUT_SIZE, INPUT_SIZE).to(DEVICE)

# THIS SNIPPET RUNS A TEST OPTIMIZATION
#
# Student built from MODEL_PATH with trainable_stem=False, a MaskedMseLoss and
# an optimizer created without extra init parameters.
s1 = StudentLinear(MODEL_PATH, HALF_PRECISION, trainable_stem=False)
loss_fn = MaskedMseLoss()
opt = get_sgd_optimizer(s1.parameters(), half_precision=HALF_PRECISION)
# max_lr equals min_lr and every scale factor is 1.0; presumably this keeps
# the learning rate fixed at 0.001 -- confirm against SgdrScheduler.
lr_scheduler = SgdrScheduler(
    opt.optimizer,
    max_lr=0.001, min_lr=0.001, period=200,
    scale_max_lr=1.0, scale_min_lr=1.0, scale_period=1.0
)
LOGGER, tb_log_dir = create_logger("test_log", "log", "valid")
TB_LOGGER = SummaryWriter(log_dir=os.path.join("tb_log", tb_log_dir))

# 100 random input batches plus 100 batches of single-channel 123x123 targets.
dummy_data = torch.rand(
    (100, BATCH_SIZE, 3, INPUT_SIZE, INPUT_SIZE)
).to(DEVICE)
dummy_targets = torch.rand((100, BATCH_SIZE, 1, 123, 123)).to(DEVICE)
# some dummy targeting: only the first 10 rows keep their random values
dummy_targets[..., 10:, :] = 0
# The mask starts as all ones and is zeroed over those same first 10 rows.
mask = torch.ones((BATCH_SIZE, 1, 123, 123)).to(DEVICE)
mask[..., :10, :] = 0