def get_model():
    """Build the sentinel DenseNet-121 network and wrap it in ``Model``.

    The backbone is torchvision's ``densenet121`` created with
    ``pretrained=True``. ``SentinelDenseNet`` also receives the default
    dataset's ``targets`` (presumably to size/label its output head — confirm).

    Returns:
        A ``Model`` instance wrapping the ``SentinelDenseNet``.
    """
    backbone = models.densenet121(pretrained=True)
    # NOTE(review): ``pretrained=`` is deprecated in newer torchvision releases
    # in favour of ``weights=``; kept as-is for compatibility.
    net = SentinelDenseNet(backbone, get_data().targets)
    return Model(net)
def load_model():
    """Rebuild the model, load weights from the path ``CP``, and return it in eval mode.

    ``CP`` is not defined in this function; it is resolved from module scope.

    NOTE(review): a later ``def load_model(metric)`` in this file rebinds the
    name ``load_model``, so this zero-argument version is unreachable once the
    module is imported. Consider removing it.

    Returns:
        ``(model, mapper)`` — the underlying network with checkpoint weights
        loaded and switched to eval mode, and ``get_data().targets``.
    """
    wrapper = get_model()
    data = get_data()
    net = wrapper.model
    targets = data.targets
    net.load_state_dict(torch.load(CP))
    net.eval()
    return net, targets
def load_model(metric):
    """Load the fine-tuned sentinel model checkpoint selected by ``metric``.

    The checkpoint file is ``single_d121_ur_phase3_bal_4_<metric>.pt``. For
    example, ``metric='final'`` selects the file written at the end of
    ``main()``.

    Args:
        metric: Suffix interpolated into the checkpoint file name; also passed
            to ``get_data``.

    Returns:
        ``(model, mapper)`` — the underlying network with the checkpoint
        weights loaded and switched to eval mode, and ``db.targets`` from the
        data bunch.
    """
    ckpt_path = f'single_d121_ur_phase3_bal_4_{metric}.pt'
    # get_model() is defined without parameters; passing ``metric`` to it
    # raised TypeError on every call.
    wrapper = get_model()
    # NOTE(review): get_data's signature is not visible here and main() only
    # calls it with img_sz=/bs= keywords — confirm a positional ``metric`` is
    # really accepted (and not taken as the image size).
    db = get_data(metric)
    model = wrapper.model
    mapper = db.targets
    model.load_state_dict(torch.load(ckpt_path))
    model.eval()
    return model, mapper
def _train_phase(learn, db, header, schedule, opt_fn, cycle_fn, checkpoint):
    """Print ``header`` and ``learn.wrapper.grads``, fit once per ``(epochs, lr)`` in ``schedule``, then save weights to ``checkpoint``."""
    print(header, learn.wrapper.grads)
    for epochs, lr in schedule:
        learn.fit(epochs, lr, db, opt_fn, cycle_fn)
    torch.save(learn.wrapper.model.state_dict(), checkpoint)


def main():
    """Fine-tune the DenseNet-121 sentinel model in two phases, checkpointing each.

    Phase 3:
        ``freeze_features(False)``, 224px images, batch size 32. Runs two
        cyclic-LR fits (2 epochs at lr 1e-4, then 2 epochs at lr 1e-5).
        Saved to ``single_d121_ur_phase3_bal_4.pt``.

    Phase 4:
        Reloads the phase-3 checkpoint, then ``freeze_features(True)``, 224px
        images, batch size 64. Runs one cyclic-LR fit (2 epochs at lr 1e-5).
        Saved to ``single_d121_ur_phase3_bal_4_final.pt``.

    Phases not run here:
        - Phase 1: ``freeze_features()``, 224px, batch 64, 2 epochs at lr 1e-2
          then 2 at 1e-3.
        - Phase 2: ``partial_freeze_features(0.7)``, 2 epochs at lr 1e-4.
        - An older 299px phase-4 variant.
    """
    wrapper = get_model()
    opt_fn = optim.Adam
    # Cyclic LR in 'triangular2' mode; momentum cycling is turned off
    # (the optimizer used is Adam).
    cycle_fn = partial(optim.lr_scheduler.CyclicLR, mode='triangular2',
                       cycle_momentum=False)
    learn = Learner(wrapper)

    phase3_ckpt = 'single_d121_ur_phase3_bal_4.pt'

    # Phase 3: feature freezing disabled; 224px, batch 32; 2 + 2 epochs.
    db = get_data(img_sz=224, bs=32)
    learn.wrapper.freeze_features(False)
    _train_phase(learn, db, '\n\nPhase 3, 224, 32', [(2, 1e-4), (2, 1e-5)],
                 opt_fn, cycle_fn, phase3_ckpt)

    # Phase 4: restart from the phase-3 checkpoint with freeze_features(True);
    # 224px, batch 64; 2 epochs.
    # NOTE(review): an earlier comment described this phase as "Freeze: None";
    # confirm that freezing the features here is intended.
    db = get_data(img_sz=224, bs=64)
    learn.wrapper.model.load_state_dict(torch.load(phase3_ckpt))
    learn.wrapper.freeze_features(True)
    _train_phase(learn, db, '\n\nPhase 4, 224, 64', [(2, 1e-5)],
                 opt_fn, cycle_fn, 'single_d121_ur_phase3_bal_4_final.pt')