else: savePath = os.path.join(dirPath, "model_epoch_{}.pth".format(state['epoch'])) th.save(state, savePath) print("===> Checkpoint saved to {}".format(savePath)) tic() epoch_loss = float('inf') for epoch in range(start + 1, opt.nEpochs + 1): if opt.saveBest: epoch_loss_new = train(epoch) else: train(epoch) test() if not epoch % opt.saveFreq: if opt.saveBest and epoch_loss_new < epoch_loss: epoch_loss = epoch_loss_new state = {'epoch':epoch, 'model_state_dict':model.state_dict(),\ 'optimizer_state_dict':optimizer.state_dict(),\ 'rng_state':th.get_rng_state(),'params':params} save_checkpoint(state) elif not opt.saveBest: state = {'epoch':epoch, 'model_state_dict':model.state_dict(),\ 'optimizer_state_dict':optimizer.state_dict(),\ 'rng_state':th.get_rng_state(),'params':params} save_checkpoint(state) print("******************************************************") print( "\n ============ Training completed in {:.4f} seconds ======================\n" .format(toc()))
def copyModelParams(model, listModel):
    """Copy the parameters of each independently trained single-stage
    network in ``listModel`` into the corresponding stage of ``model``.

    Stage membership is detected by the ``'.<stage>.'`` marker that the
    multi-stage model's state-dict keys carry; keys within a stage are
    matched to the single-stage model's keys positionally, so both state
    dicts are assumed to enumerate parameters in the same order.

    Parameters
    ----------
    model : torch.nn.Module
        Multi-stage network receiving the parameters; must expose a
        ``stages`` attribute.
    listModel : sequence of torch.nn.Module
        One trained single-stage network per stage of ``model``.

    Raises
    ------
    AssertionError
        If ``model.stages`` differs from ``len(listModel)``.
    """
    assert model.stages == len(listModel), \
        "stages mismatch between the model and the listModel."
    # Hoist the state dicts out of the loops: state_dict() rebuilds its
    # mapping on every call, so calling it inside the copy loop is wasted
    # work.  The returned tensors still alias the live parameters, so
    # copy_() updates the model in place.
    mstate = model.state_dict()
    mkeys = list(mstate.keys())
    for stage in range(model.stages):
        marker = '.' + str(stage) + '.'
        skeys = [key for key in mkeys if marker in key]
        sstate = listModel[stage].state_dict()
        lkeys = list(sstate.keys())
        for i, j in zip(skeys, lkeys):
            mstate[i].copy_(sstate[j])


# Create a model of opt.stages stages that will consist of all the
# independently trained stages.
params['stages'] = opt.stages
model = UDNet(*params.values())

# Copy the parameters of each 1-stage network to the correct stage of the
# newly created model.
copyModelParams(model, Lmodel)

# Save the final model (weights + construction params + RNG state so the
# run can be reproduced / resumed).
state = {'params': params,
         'model_state_dict': model.state_dict(),
         'rng_state': th.get_rng_state()}
savePath = os.path.join(dirPath, "model_final.pth")
th.save(state, savePath)
print("===> Final model saved to {}".format(savePath))