Exemplo n.º 1
0
def load_DNet_for_train(cfg, dict_DB):
    """Create the DNet training components and register them in ``dict_DB``.

    A ``DNet`` is built from ``cfg`` and moved to the GPU, then an Adam
    optimizer and a ``MultiStepLR`` scheduler are created from the
    ``lr``, ``weight_decay``, ``milestones`` and ``gamma`` attributes of
    ``cfg``. When ``cfg.resume == False`` the file
    ``cfg.weight_dir + 'checkpoint_DNet_final'`` is loaded and used to
    restore the model weights (non-strict), the optimizer state, the
    scheduler position, and the ``'epoch'`` / ``'val_result'`` entries of
    ``dict_DB``.

    Args:
        cfg: Configuration object; the attributes listed above plus
            ``resume`` and ``weight_dir`` are read.
        dict_DB: Mapping that receives the ``'DNet'``, ``'optimizer'``,
            ``'scheduler'`` and ``'loss_fn'`` entries (and ``'epoch'`` /
            ``'val_result'`` when a checkpoint is restored).

    Returns:
        The same ``dict_DB`` object, updated in place.
    """
    net = DNet(cfg=cfg)
    net.cuda()

    adam = torch.optim.Adam(
        params=net.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )
    lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
        optimizer=adam,
        milestones=cfg.milestones,
        gamma=cfg.gamma,
    )

    # NOTE(review): the checkpoint is restored when ``resume`` is False,
    # which reads as inverted relative to the flag's name -- confirm this
    # matches the project's config convention before changing it.
    if cfg.resume == False:
        ckpt_path = cfg.weight_dir + 'checkpoint_DNet_final'
        state = torch.load(ckpt_path)

        # strict=False: the state-dict keys need not match the model exactly.
        net.load_state_dict(state['model'], strict=False)
        adam.load_state_dict(state['optimizer'])

        # Replace the fresh scheduler with one positioned at the stored epoch.
        lr_schedule = torch.optim.lr_scheduler.MultiStepLR(
            optimizer=adam,
            milestones=cfg.milestones,
            gamma=cfg.gamma,
            last_epoch=state['epoch'],
        )

        dict_DB['epoch'] = state['epoch']
        dict_DB['val_result'] = state['val_result']

    criterion = Loss_Function()

    dict_DB.update({
        'DNet': net,
        'optimizer': adam,
        'scheduler': lr_schedule,
        'loss_fn': criterion,
    })
    return dict_DB
Exemplo n.º 2
0
def load_DNet_for_test(cfg, dict_DB):
    """Load a DNet checkpoint for evaluation and store the model in ``dict_DB``.

    The checkpoint file is read first; afterwards a ``DNet`` is built from
    ``cfg``, its weights are restored non-strictly from the checkpoint's
    ``'model'`` entry, and the model is moved to the GPU.

    Args:
        cfg: Configuration object; ``run_mode`` is read, plus
            ``paper_weight_dir`` or ``weight_dir`` depending on its value.
        dict_DB: Mapping that receives the model under the ``'DNet'`` key.

    Returns:
        The same ``dict_DB`` object, updated in place.
    """
    use_paper_weights = cfg.run_mode == 'test_paper'

    # select ckpt from output_dir
    # NOTE(review): the non-paper branch appends an empty file name, so the
    # resulting path is just ``cfg.weight_dir`` -- presumably a placeholder
    # to be filled in with the desired checkpoint name; confirm.
    ckpt_path = (
        cfg.paper_weight_dir + 'checkpoint_DNet_paper'
        if use_paper_weights
        else cfg.weight_dir + ''
    )
    state = torch.load(ckpt_path)

    net = DNet(cfg=cfg)
    # strict=False: the state-dict keys need not match the model exactly.
    net.load_state_dict(state['model'], strict=False)
    net.cuda()

    dict_DB['DNet'] = net
    return dict_DB