else:
        savePath = os.path.join(dirPath,
                                "model_epoch_{}.pth".format(state['epoch']))
    th.save(state, savePath)
    print("===> Checkpoint saved to {}".format(savePath))


# Train for the remaining epochs, evaluating after every epoch and writing a
# checkpoint every opt.saveFreq epochs. With opt.saveBest set, a checkpoint
# epoch is only saved when the loss returned by train() is lower than the loss
# recorded at the most recently saved checkpoint (initially infinity).
tic()
epoch_loss = float('inf')
for epoch in range(start + 1, opt.nEpochs + 1):
    if not opt.saveBest:
        train(epoch)
    else:
        # train()'s return value is only needed when tracking the best loss.
        epoch_loss_new = train(epoch)
    test()

    if epoch % opt.saveFreq == 0:
        # Without saveBest every checkpoint epoch is saved; with saveBest only
        # epochs that improve on the recorded loss are.
        if not opt.saveBest or epoch_loss_new < epoch_loss:
            if opt.saveBest:
                epoch_loss = epoch_loss_new
            state = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'rng_state': th.get_rng_state(),
                'params': params,
            }
            save_checkpoint(state)
    print("******************************************************")
print(
    "\n ============ Training completed in {:.4f} seconds ======================\n"
    .format(toc()))
def copyModelParams(model, listModel):
    """Copy the parameters of independently trained stages into ``model``.

    For every stage index ``s`` in ``range(model.stages)``, the entries of
    ``model.state_dict()`` whose key contains the substring ``'.s.'`` are
    overwritten, pairwise and in order, with the entries of
    ``listModel[s].state_dict()`` via ``Tensor.copy_``.

    Args:
        model: the multi-stage network; must expose a ``stages`` attribute
            and a ``state_dict()`` method.
        listModel: sequence holding one network per stage of ``model``.

    Raises:
        ValueError: if ``len(listModel)`` differs from ``model.stages``.

    Warns:
        UserWarning: if, for some stage, the number of matching ``model`` keys
            differs from the number of keys in ``listModel[stage]``. In that
            case only the first ``min(...)`` pairs are copied, as before.
    """
    import warnings

    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if model.stages != len(listModel):
        raise ValueError(
            "stages mismatch between the model ({}) and the listModel ({})."
            .format(model.stages, len(listModel)))

    # state_dict() builds a fresh dict on every call, so fetch it once. The
    # tensors it returns alias the module's parameters/buffers, which is what
    # makes copy_ below update the model in place.
    target = model.state_dict()
    mkeys = list(target.keys())

    for stage in range(model.stages):
        # NOTE(review): plain substring match. A key with a nested module
        # index equal to `stage` (e.g. 'a.0.b.1.w' for stage 1) would also be
        # selected; confirm UDNet's key layout cannot produce such keys.
        tag = '.' + str(stage) + '.'
        skeys = [key for key in mkeys if tag in key]

        source = listModel[stage].state_dict()
        lkeys = list(source.keys())

        if len(skeys) != len(lkeys):
            # zip() stops at the shorter list; surface the silent truncation.
            warnings.warn(
                "stage {}: model has {} matching entries but listModel[{}] "
                "has {}; only the first {} pairs are copied.".format(
                    stage, len(skeys), stage, len(lkeys),
                    min(len(skeys), len(lkeys))))

        for i, j in zip(skeys, lkeys):
            target[i].copy_(source[j])


# Build a model with opt.stages stages; its parameters are filled in from the
# independently trained 1-stage networks held in Lmodel.
params['stages'] = opt.stages
# NOTE(review): positional unpacking relies on params' insertion order
# matching UDNet's constructor signature — verify if params is ever rebuilt.
model = UDNet(*params.values())
copyModelParams(model, Lmodel)

# Save the assembled model together with its construction parameters and the
# current RNG state.
state = dict(
    params=params,
    model_state_dict=model.state_dict(),
    rng_state=th.get_rng_state(),
)
savePath = os.path.join(dirPath, "model_final.pth")
th.save(state, savePath)
print("===> Final model saved to {}".format(savePath))