Ejemplo n.º 1
0
def init(mdl_name=None, ckpt_name=None):
    """Build a model together with the objects needed to train or save it.

    Args:
        mdl_name: Name of the model. It selects the save directory and the
            configuration loaded through ``config.load``. When empty or
            ``None``, ``config.DEFAULT_MODEL_NAME`` is used.
        ckpt_name: Optional checkpoint to load, in one of two forms:

            * ``"<checkpoint>"`` -- loaded from this model's own save
              directory; the checkpoint is also returned so training can
              be continued from it.
            * ``"<model>/<checkpoint>"`` -- loaded from the save directory
              of another model; it is handed to ``build_model`` but is not
              returned.

    Returns:
        A ``(model, info)`` tuple. ``info`` is a dict with the keys
        ``'voc'``, ``'checkpoint'`` (``None`` unless training continues from
        the loaded checkpoint), ``'ckpt_mng'`` and ``'model_config'``.

    Raises:
        ValueError: If ``ckpt_name`` contains more than one ``'/'``.
    """
    if not mdl_name:
        mdl_name = config.DEFAULT_MODEL_NAME

    save_path = os.path.join(DIR_PATH, config.SAVE_DIR, mdl_name)
    print('Saving path:', save_path)

    ckpt_mng = CheckpointManager(save_path)

    checkpoint, continue_training = None, False
    if ckpt_name:
        print('Load checkpoint:', ckpt_name)
        ckpt_tokens = ckpt_name.split('/')
        if len(ckpt_tokens) == 1:
            # Bare checkpoint name: it lives in this model's own directory.
            checkpoint = ckpt_mng.load(ckpt_tokens[0], device)
            continue_training = True

        elif len(ckpt_tokens) == 2:
            # "<model>/<checkpoint>": read from another model's directory.
            source_model, source_ckpt = ckpt_tokens
            load_path = os.path.join(DIR_PATH, config.SAVE_DIR, source_model)
            load_ckpt_mng = CheckpointManager(load_path)
            checkpoint = load_ckpt_mng.load(source_ckpt, device)
            continue_training = False

        else:
            # A single formatted message; passing several positional args to
            # the exception would render as a tuple instead of readable text.
            raise ValueError(
                'Invalid checkpoint path: {!r} (expected "<checkpoint>" or '
                '"<model>/<checkpoint>")'.format(ckpt_name))

    model_config = config.load(mdl_name)
    model = build_model(model_config, checkpoint)

    return model, {
        # NOTE(review): `voc` is not assigned in this function -- presumably
        # a module-level name; confirm it exists where this is called.
        'voc': voc,
        'checkpoint': checkpoint if continue_training else None,
        'ckpt_mng': ckpt_mng,
        'model_config': model_config
    }
Ejemplo n.º 2
0
def main():
    """Command-line entry point for the chatbot.

    Parses ``--mode`` (``pretrain``, ``finetune`` or ``chat``), ``--speaker``
    and ``--checkpoint``, optionally loads the checkpoint from ``SAVE_PATH``,
    builds the model, and then either trains it or starts a chat session for
    the requested speaker. When no mode is given nothing further is run.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-m', '--mode',
        choices={'pretrain', 'finetune', 'chat'},
        help="mode to run the chatbot",
    )
    cli.add_argument('-s', '--speaker', default='<none>')
    cli.add_argument('-cp', '--checkpoint')
    options = cli.parse_args()

    print('Saving path:', SAVE_PATH)
    manager = CheckpointManager(SAVE_PATH)

    if options.checkpoint:
        print('Load checkpoint:', options.checkpoint)
        restored = manager.load(options.checkpoint, device)
    else:
        restored = None

    model, voc, persons = build_model(restored)

    # Both training flavours go through the same routine; the mode is
    # forwarded so train() can tell them apart.
    if options.mode in ('pretrain', 'finetune'):
        train(options.mode, model, voc, persons, restored, manager)
        return

    if options.mode != 'chat':
        return

    speaker = options.speaker
    if not persons.has(speaker):
        print('Invalid speaker. Possible speakers:', persons.tokens)
        return

    print('Selected speaker:', speaker)
    chat(model, voc, persons.get_index(speaker))
Ejemplo n.º 3
0
def init():
    """Prepare a search decoder from an optional checkpoint.

    Reads ``--checkpoint`` from the command line, loads it from ``SAVE_PATH``
    when given, builds the model, switches it to evaluation mode and wraps
    it in a decoder.

    Returns:
        A ``(searcher, voc)`` tuple. ``searcher`` is a ``BeamSearchDecoder``
        when ``config.BEAM_SEARCH_ON`` is truthy, otherwise a
        ``GreedySearchDecoder``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('-cp', '--checkpoint')
    options = cli.parse_args()

    # The manager is created even when no checkpoint was requested.
    manager = CheckpointManager(SAVE_PATH)
    restored = None
    if options.checkpoint:
        restored = manager.load(options.checkpoint, device)

    model, voc = build_model(restored)

    # Evaluation mode (per the original note: affects the dropout layers).
    model.eval()

    # Choose the search strategy from the configuration flag.
    if config.BEAM_SEARCH_ON:
        decoder_cls = BeamSearchDecoder
    else:
        decoder_cls = GreedySearchDecoder
    return decoder_cls(model), voc
Ejemplo n.º 4
0
def main():
    """Command-line entry point: train the network or run it.

    Parses ``--mode`` (``train`` or ``run``) and ``--checkpoint``, optionally
    loads the checkpoint from ``SAVE_PATH``, builds the model and dispatches
    to ``train`` or ``run``. When no mode is given nothing further is run.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-m', '--mode',
        choices={'train', 'run'},
        help="mode to run the network",
    )
    cli.add_argument('-cp', '--checkpoint')
    options = cli.parse_args()

    print('Saving path:', SAVE_PATH)
    manager = CheckpointManager(SAVE_PATH)

    if options.checkpoint:
        print('Load checkpoint:', options.checkpoint)
        restored = manager.load(options.checkpoint, device)
    else:
        restored = None

    model, voc = build_model(restored)

    if options.mode == 'train':
        train(options.mode, model, voc, restored, manager)
        return

    if options.mode == 'run':
        run(model, voc)
Ejemplo n.º 5
0
def main():
    """Command-line entry point: train, test or run the network.

    Declares the command-line options, prints the parsed arguments, derives
    the save directory from ``--save_dir`` and ``--model_name``, optionally
    loads ``--checkpoint`` from it, builds the model and dispatches on
    ``--mode``. When no mode is given nothing further is run.
    """
    # (flags, keyword arguments) for each option, in declaration order so the
    # generated help lists them in the same sequence.
    option_specs = (
        (("-m", "--mode"),
         dict(choices={"train", "test", "run"},
              help="mode to run the network")),
        (("-cp", "--checkpoint"), dict()),
        (("-st", "--set"),
         dict(choices={"train", "test"}, default="test")),
        (("-im", "--image"), dict()),
        (("-n", "--num"), dict(type=int, default=4)),
        (("--train_corpus",),
         dict(type=str, default="data/dataset_celeba/train")),
        (("--test_corpus",),
         dict(type=str, default="data/dataset_celeba/val")),
        (("--ep",), dict(type=int, default=20, help="number of epochs")),
        (("--save_freq",),
         dict(type=int, default=1, help="save checkpoint every x epochs")),
        (("--save_dir",),
         dict(type=str, default="checkpoints", help="folder for models")),
        (("--model_name",),
         dict(type=str, default="default", help="model name")),
        (("--imsize",),
         dict(type=int, default=448, help="training image size")),
        (("--workers",),
         dict(type=int, default=8, help="number of workers")),
        (("--bs",), dict(type=int, default=8, help="batch size")),
        (("--lr",), dict(type=float, default=3e-4, help="learning rate")),
        (("--optimizer",),
         dict(type=str, default="adam", help="optimizer (adam; sgd)")),
        (("--wup",),
         dict(type=int, default=0, help="number of warm up epochs")),
        (("--lr_schedule",),
         dict(type=str,
              default="",
              help="learning rate schedule (multi_step_lr; cosine)")),
        (("--print_freq",),
         dict(type=int, default=100, help="print stats every x iterations")),
        (("--grad_lambda",),
         dict(type=float, default=0.5, help="gradient loss lambda")),
    )

    cli = argparse.ArgumentParser()
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)
    args = cli.parse_args()

    print("args: ", args)

    save_path = os.path.join(DIR_PATH, args.save_dir, args.model_name)
    print("Saving path:", save_path)
    manager = CheckpointManager(save_path)

    if args.checkpoint:
        print("Load checkpoint:", args.checkpoint)
        checkpoint = manager.load(args.checkpoint, device)
    else:
        checkpoint = None

    model = build_model(checkpoint)

    # Lambdas keep each handler name unresolved until its mode is selected.
    handlers = {
        "train": lambda: train(args, model, checkpoint, manager),
        "test": lambda: test(args, model, checkpoint),
        "run": lambda: run(args, model, checkpoint),
    }
    handler = handlers.get(args.mode)
    if handler is not None:
        handler()