Example #1
def generate_lyrics_wrapper(genres, keywords):
    args = init_args()
    device, n_gpu = U.get_device(logger)

    # Reload the model and the tokenizer
    model = GPT2LMHeadModel.from_pretrained(args.load_model_dir)
    enc = GPT2Tokenizer.from_pretrained(args.load_model_dir)

    model.eval()
    U.set_seed(np.random.randint(0, 100))
    # U.set_seed(2)

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    # @                    GENERATE FROM FINE-TUNED GPT2
    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

    tags = genres

    context = "[s:tags]" + tags + "[e:tags]" + \
              "[s:keywords]" + keywords + "[e:keywords]" + \
              "[s:lyrics]"
    end_token = "[e:lyrics]"
    context = enc.encode(context)
    sequence_batch = generate_lyrics(model, enc, args, context, end_token,
                                     device)

    lyrics = [enc.decode(seq) for seq in sequence_batch]
    for lyric in lyrics:
        print(lyric)
        print("\n---------------\n")
    # Return the first generated lyric as a decoded string
    return lyrics[0]
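For reference, a minimal call sketch follows; the genre and keyword strings are placeholder values, and it assumes the project helpers used above (init_args, U, generate_lyrics) are importable in the same module.

# Hypothetical usage; "Pop Rock" and the keyword string are illustrative only.
lyrics = generate_lyrics_wrapper(genres="Pop Rock", keywords="love night city")
print(lyrics)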
Example #2
def main(context):
    # args = init_args()
    load_model_dir = "/Users/aw678/PycharmProjects/lyrics_generator_flask_app/tuned_models/genius_lyrics_v2/gpt2_13-11-2019@18'25/model_epoch_20"
    gen_batch = 2
    device, n_gpu = U.get_device(logger)

    # Reload the model and the tokenizer

    model = GPT2LMHeadModel.from_pretrained(load_model_dir)
    enc = GPT2Tokenizer.from_pretrained(load_model_dir)

    model.eval()
    U.set_seed(np.random.randint(0, 100))
    # U.set_seed(2)

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    # @                    GENERATE FROM FINE-TUNED GPT2
    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

    # genre = "Pop"
    # artist = "Justin Bieber"
    # year = "2015"
    # album = "Purpose"
    # song_name = "Love Yourself"
    #
    # context = "[s:genre]" + genre + "[e:genre]" + \
    #           "[s:artist]" + artist + "[e:artist]" + \
    #           "[s:year]" + year + "[e:year]" + \
    #           "[s:album]" + album + "[e:album]" + \
    #           "[s:song_name]" + song_name + "[e:song_name]" + \
    #           "[s:lyrics]"
    #
    # context = "[s:genre]" + genre + "[e:genre]" + \
    #           "[s:artist]" + artist + "[e:artist]" + \
    #           "[s:lyrics]"

    end_token = "[e:lyrics]"

    context = enc.encode(context)

    sequence_batch = generate_lyrics(model, enc, gen_batch, context, end_token,
                                     device)

    lyrics_list = []
    for seq in sequence_batch:
        lyrics_list.append(enc.decode(seq))

    return lyrics_list
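A sketch of how this variant might be driven, assuming the caller assembles the tagged context string itself; the genre and artist values are placeholders taken from the commented-out block above.

# Hypothetical usage; main() expects the raw tagged context as a string.
genre = "Pop"
artist = "Justin Bieber"
context = "[s:genre]" + genre + "[e:genre]" + \
          "[s:artist]" + artist + "[e:artist]" + \
          "[s:lyrics]"
for lyrics in main(context):
    print(lyrics)
    print("\n---------------\n")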
Example #3
def main():
    args = init_args()
    device, n_gpu = U.get_device(logger)

    # Reload the model and the tokenizer
    model = GPT2LMHeadModel.from_pretrained(args.load_model_dir)
    enc = GPT2Tokenizer.from_pretrained(args.load_model_dir)

    model.eval()
    U.set_seed(np.random.randint(0, 100))
    # U.set_seed(2)

    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@
    # @                    GENERATE FROM FINE-TUNED GPT2
    # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@

    genre = "Pop"
    artist = "Justin Bieber"
    year = "2015"
    album = "Purpose"
    song_name = "Love Yourself"

    context = "[s:genre]" + genre + "[e:genre]" + \
              "[s:artist]" + artist + "[e:artist]" + \
              "[s:year]" + year + "[e:year]" + \
              "[s:album]" + album + "[e:album]" + \
              "[s:song_name]" + song_name + "[e:song_name]" + \
              "[s:lyrics]"

    # This shorter genre/artist context replaces the full context built above,
    # so year, album and song_name are not used in this variant.
    context = "[s:genre]" + genre + "[e:genre]" + \
              "[s:artist]" + artist + "[e:artist]" + \
              "[s:lyrics]"

    end_token = "[e:lyrics]"

    context = enc.encode(context)

    sequence_batch = generate_lyrics(model, enc, args, context, end_token,
                                     device)

    for seq in sequence_batch:
        print(enc.decode(seq))
        print("\n---------------\n")
Example #4
def run(args):

    set_seed(args.seed)

    set_logging(ROOT_DIR, args)
    # Log the full argument set, whether args is a Namespace or a plain dict
    import pprint
    logging.info(pprint.pformat(args if isinstance(args, dict) else vars(args)))

    # set up cuda device
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device)
    device = torch.device('cuda')

    loader = Dev_Loader(is_divide_variance=args.is_divide_variance)

    train_loader = loader.train(batch_size=args.batch_size)
    val_loader = loader.val(batch_size=args.batch_size)

    # model = getattr(net_archs, args.net)(args).cuda()
    from xception import ModifiedXception
    model = ModifiedXception(num_classes=args.nb_class,
                             drop_rate=args.drop_rate,
                             decay=args.decay).cuda()

    if args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.init_lr,
                              momentum=0.9,
                              nesterov=True)
    elif args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.init_lr,
                               weight_decay=args.l2)
    else:
        # Avoid a NameError further down if an unknown optimizer name is passed
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))
    if args.lr_factor < 1.0:
        scheduler = ReduceLROnPlateau(optimizer,
                                      mode='max',
                                      verbose=True,
                                      factor=args.lr_factor,
                                      patience=args.lr_patience)

    train_hist = History(name='train')
    test_list = ['a', 'bc', 'abc']
    val_hist = dict()
    for d in test_list:
        val_hist[d] = History(name='val/{}'.format(d))

    if args.continue_run:
        ckpt_file = Reporter(exp=args.exp).select_last(
            args.ckpt_prefix[0:5]).selected_ckpt
        logging.info('continue training from {}'.format(ckpt_file))

        ckpt_dicts = torch.load(ckpt_file)

        model.load_state_dict(ckpt_dicts['model_state_dict'])
        model.cuda()

        optimizer.load_state_dict(ckpt_dicts['optimizer_state_dict'])

        start_epoch = ckpt_dicts['epoch'] + 1
    else:
        start_epoch = 1

    # checkpoint after new History, order matters
    ckpter = CheckPoint(model=model,
                        optimizer=optimizer,
                        path='{}/ckpt/{}'.format(ROOT_DIR, args.exp),
                        prefix=args.ckpt_prefix,
                        interval=1,
                        save_num=1)

    for epoch in range(start_epoch, args.run_epochs):

        train_mixup_all(train_loader,
                        model,
                        optimizer,
                        device,
                        mix_alpha=args.mix_alpha)

        train_hist.add(logs=eval_model(train_loader, model, device),
                       epoch=epoch)

        a_logs = eval_model(val_loader['a'], model, device)
        bc_logs = eval_model(val_loader['bc'], model, device)
        avg_loss = (a_logs['loss'] + bc_logs['loss']) / 2
        avg_acc = (a_logs['acc'] + bc_logs['acc']) / 2
        avg_logs = {'loss': avg_loss, 'acc': avg_acc}
        val_hist['a'].add(logs=a_logs, epoch=epoch)
        val_hist['bc'].add(logs=bc_logs, epoch=epoch)
        val_hist['abc'].add(logs=avg_logs, epoch=epoch)

        if args.lr_factor < 1.0:
            scheduler.step(val_hist['abc'].recent['acc'])

        # plotting
        if args.plot:
            train_hist.clc_plot()
            for d in test_list:
                val_hist[d].plot()

        # logging
        logging.info("Epoch{:04d},{:6},{}".format(epoch, train_hist.name,
                                                  str(train_hist.recent)))
        for d in test_list:
            logging.info("Epoch{:04d},{:6},{}".format(epoch, val_hist[d].name,
                                                      str(val_hist[d].recent)))

        ckpter.check_on(epoch=epoch,
                        monitor='acc',
                        loss_acc=val_hist['abc'].recent)

    # explicitly save last
    ckpter.save(epoch=args.run_epochs - 1,
                monitor='acc',
                loss_acc=val_hist['abc'].recent)