예제 #1
0
def torch_resume(snapshot_path, trainer):
    """Restore a training run from a snapshot written for pytorch

    The snapshot is read as a mapping with the entries ``'trainer'``,
    ``'model'`` and ``'optimizer'``; each entry is loaded back into the
    corresponding part of ``trainer``.

    :param str snapshot_path: snapshot file path
    :param instance trainer: trainer instance whose ``updater`` holds the
        model and the ``'main'`` optimizer to restore
    """
    def _keep_storage(storage, location):
        # Hand the storage back untouched so torch.load does not try to move
        # it to the device it was saved from.
        return storage

    state = torch.load(snapshot_path, map_location=_keep_storage)

    # trainer bookkeeping
    NpzDeserializer(state['trainer']).load(trainer)

    # Find the object that actually owns the weights: first step into a
    # ``.model`` attribute if there is one (TTS case per the original notes,
    # otherwise ASR), then into a ``.module`` attribute if there is one
    # (presumably a DataParallel-style wrapper).
    net = trainer.updater.model
    if hasattr(net, "model"):
        net = net.model
    if hasattr(net, "module"):
        net = net.module
    net.load_state_dict(state['model'])

    # optimizer state
    main_optimizer = trainer.updater.get_optimizer('main')
    main_optimizer.load_state_dict(state['optimizer'])

    # drop the reference to the loaded snapshot
    del state
def torch_resume(snapshot_path, trainer):
    """Restore trainer, model and optimizer from a pytorch snapshot.

    After the trainer state is restored, the best value recorded by the
    ``snapshot_object`` extension trigger is printed.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Trainer whose ``updater`` holds
            the model and the ``'main'`` optimizer to restore.

    """
    def _keep_storage(storage, location):
        # Returning the storage unchanged stops torch.load from remapping it
        # to the device it was saved from.
        return storage

    state = torch.load(snapshot_path, map_location=_keep_storage)

    # Trainer state first, then report the best value stored with it.
    trainer_state = state['trainer']
    NpzDeserializer(trainer_state).load(trainer)
    best_value = trainer_state['extension_triggers/snapshot_object/best_value']
    print('| resumed best value loss = {}'.format(best_value))

    # Step into ``.model`` when present (TTS case per the original notes,
    # otherwise ASR), then into ``.module`` when present (presumably a
    # DataParallel-style wrapper), and load the weights there.
    net = trainer.updater.model
    if hasattr(net, "model"):
        net = net.model
    if hasattr(net, "module"):
        net = net.module
    net.load_state_dict(state['model'])

    # Optimizer state.
    main_optimizer = trainer.updater.get_optimizer('main')
    main_optimizer.load_state_dict(state['optimizer'])

    # Drop the reference to the loaded snapshot.
    del state
예제 #3
0
def torch_resume(snapshot_path,
                 trainer,
                 weight_sharing=False,
                 reinit_adv=False):
    """Function to resume from snapshot for pytorch

    :param str snapshot_path: snapshot file path
    :param instance trainer: chainer trainer instance
    :param bool weight_sharing: if True, only the model weights are taken
        from the snapshot (loaded non-strictly); the trainer and optimizer
        states are left untouched
    :param bool reinit_adv: if True, entries of the snapshot's model state
        whose key starts with ``'predictor.adv'`` are discarded before
        loading, so those parameters keep the values the current model
        already has. Only applied on the branch without a ``.model``
        attribute (the ASR case)
    """
    # load snapshot; the identity map_location keeps storages where
    # torch.load deserialized them instead of remapping to the saved device
    snapshot_dict = torch.load(snapshot_path,
                               map_location=lambda storage, loc: storage)

    # restore trainer states
    if not weight_sharing:
        d = NpzDeserializer(snapshot_dict['trainer'])
        d.load(trainer)

    # restore model states
    model = trainer.updater.model
    if hasattr(model, "model"):
        # (for TTS model)
        model = model.model
        if hasattr(model, "module"):
            model = model.module
        model.load_state_dict(snapshot_dict['model'])
    else:
        # (for ASR model)
        model_state = snapshot_dict['model']

        # reinitialize only the adversarial branch to move it away from a
        # possible local optima
        adv_keys = []
        if reinit_adv:
            logging.info(
                "Removing the learnt weights of adversarial branch ...")
            # Collect the keys first: deleting while iterating the live key
            # view raises RuntimeError on Python 3.
            adv_keys = [k for k in model_state
                        if k.startswith('predictor.adv')]
            for k in adv_keys:
                del model_state[k]

        # Once keys have been dropped a strict load would reject the state
        # dict for missing entries, so strictness is relaxed in that case.
        strict = not weight_sharing and not adv_keys

        if hasattr(model, "module"):
            model = model.module
        model.load_state_dict(model_state, strict=strict)

    # restore optimizer states
    if not weight_sharing:
        trainer.updater.get_optimizer('main').load_state_dict(
            snapshot_dict['optimizer'])

    # delete opened snapshot
    del snapshot_dict
예제 #4
0
def torch_resume(snapshot_path, trainer):
    """Restore trainer, model and optimizer states from a pytorch snapshot.

    The snapshot is read as a mapping with the entries ``"trainer"``,
    ``"model"`` and ``"optimizer"``.

    Args:
        snapshot_path (str): Snapshot file path.
        trainer (chainer.training.Trainer): Trainer whose ``updater`` holds
            the model and the ``"main"`` optimizer to restore.

    """
    def _keep_storage(storage, location):
        # Identity mapping: torch.load leaves the storage where it was
        # deserialized rather than moving it to the saved device.
        return storage

    state = torch.load(snapshot_path, map_location=_keep_storage)

    # Trainer state.
    NpzDeserializer(state["trainer"]).load(trainer)

    # Walk down to the object that owns the weights. ``model`` is checked
    # first (TTS case per the original notes, otherwise ASR), then ``module``
    # on whatever that yielded (presumably a DataParallel-style wrapper).
    net = trainer.updater.model
    for wrapper_attr in ("model", "module"):
        if hasattr(net, wrapper_attr):
            net = getattr(net, wrapper_attr)
    net.load_state_dict(state["model"])

    # Optimizer state.
    main_optimizer = trainer.updater.get_optimizer("main")
    main_optimizer.load_state_dict(state["optimizer"])

    # Drop the reference to the loaded snapshot.
    del state
예제 #5
0
def main():
    """Run the trained encoder/decoder on a single image and display it.

    Steps:
      1. Parse ``--gpu``, ``--img`` and ``--out`` from the command line.
      2. Load encoder and decoder weights from ``.npz`` files inside the
         ``--out`` directory.
      3. Zero-pad the input image so height and width are multiples of 256.
      4. Run encoder + decoder ten times, print the mean time per run and
         append it to ``time.txt`` in the ``--out`` directory.
      5. Crop the padding off the output, place input and output next to
         each other, label them and show the result in an OpenCV window
         until a key is pressed.
    """
    cli = argparse.ArgumentParser(description='chainer implementation of pix2pix')
    cli.add_argument('--gpu', '-g', type=int, default=-1,
                     help='GPU ID (negative value indicates CPU)')
    cli.add_argument('--img', '-i', help='Input image')
    cli.add_argument('--out', '-o', default='result_dehighlight',
                     help='Directory to output the result')
    args = cli.parse_args()
    use_gpu = args.gpu >= 0

    encoder_weights = os.path.join(args.out, "enc_iter_2500000.npz")
    # Kept as a list so decoder weights split over several archives can be
    # loaded one after another.
    decoder_weights = [os.path.join(args.out, "dec_iter_2500000.npz")]

    # Build the networks and fill in their trained parameters.
    enc = Encoder(in_ch=3)
    dec = Decoder(out_ch=3)

    chainer.serializers.load_npz(encoder_weights, enc)
    for weight_file in decoder_weights:
        with np.load(weight_file) as archive:
            # strict=False: presumably lets an archive holding only part of
            # the decoder's parameters load without error -- TODO confirm.
            NpzDeserializer(archive, strict=False).load(dec)

    if use_gpu:
        chainer.cuda.get_device(args.gpu).use()  # make the chosen GPU current
        enc.to_gpu()
        dec.to_gpu()

    source = loadimg(args.img)
    ch, h, w = source.shape
    # Zero-padded canvas whose spatial size is rounded up to multiples of 256.
    padded_shape = (ch, math.ceil(h/256)*256, math.ceil(w/256)*256)
    padded = np.zeros(padded_shape, dtype="f")
    padded[:, :h, :w] = source
    batch = padded[np.newaxis, ...]  # leading axis of size 1 (minibatch)
    print(batch.shape)

    x_in = chainer.Variable(batch)
    if use_gpu:
        x_in.to_gpu()

    # Time repeated forward passes; the output of the last pass is kept.
    n_runs = 10
    started = time.time()
    for _ in tqdm(range(n_runs)):
        x_out = dec(enc(x_in))
    ts = (time.time() - started) / n_runs
    print("mean estimation time:{:.2f}".format(ts))
    with open(os.path.join(args.out, "time.txt"), "a") as timing_log:
        timing_log.write("gpu:{}, time:{:.4f}, FPS:{:.4f}\n".format(args.gpu, ts, 1/ts))

    # On GPU the array has to be fetched with .get() before indexing.
    if use_gpu:
        result = x_out.data.get()[0]
    else:
        result = x_out.data[0]
    restored = result[:, :h, :w]  # trim paddings

    # Input and output joined along the last axis, then labelled.
    side_by_side = np.concatenate((source, restored), axis=2)
    canvas = to_bgr(side_by_side).copy()
    for label, x_offset in (("input", 3), ("output", w + 3)):
        cv2.putText(canvas, label, (x_offset, 15),
                    cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 0, 0))
    cv2.imshow("result", canvas)
    cv2.waitKey(0)
    cv2.destroyAllWindows()