print('Greedy ratio:', greedy_ratio)
print('Beam size:', beam_size)
print('Output directory:', output_dir)
print('Controls:', control)
print('Temperature:', temperature)
print('Init zero:', init_zero)
print('-' * 70)

#========================================================================
# Generating
#========================================================================

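# A session checkpoint is a dict holding the model's constructor kwargs
# ('model_config') and its trained weights ('model_state').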
state = torch.load(sess_path)
model = PerformanceRNN(**state['model_config']).to(device)
model.load_state_dict(state['model_state'])
model.eval()
print(model)
print('-' * 70)

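# Seed the generator with an initial latent vector: all zeros for a
# deterministic start, or a standard-normal sample for varied outputs.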
if init_zero:
    init = torch.zeros(batch_size, model.init_dim).to(device)
else:
    init = torch.randn(batch_size, model.init_dim).to(device)

with torch.no_grad():
    if use_beam_search:
        outputs = model.beam_search(init,
                                    max_len,
                                    beam_size,
                                    controls=controls,
                                    temperature=temperature)
    else:
        # Plain autoregressive sampling when beam search is disabled
        outputs = model.generate(init,
                                 max_len,
                                 controls=controls,
                                 greedy=greedy_ratio,
                                 temperature=temperature)

#========================================================================
# Example 2
#========================================================================

# Imports inferred from the calls below; the remaining project-specific
# names (utils, PerformanceRNN, EventSequenceEncoder, device, and the
# module-level discriminator_config dict) are assumed to be provided by
# the surrounding package.
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def pretrain_discriminator(
        model_sess_path,  # load
        discriminator_sess_path,  # load + save
        batch_data_generator,  # Dataset(...).batches(...)
        discriminator_config_overwrite=None,
        gradient_clipping=False,
        control_ratio=1.0,
        num_iter=-1,
        save_interval=60.0,
        discriminator_lr=0.001,
        enable_logging=False,
        auto_sample_factor=False,
        sample_factor=1.0):

    print('-' * 70)
    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])

    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config
    try:
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess[
            'discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    except FileNotFoundError:
        print(f'New discriminator session at "{discriminator_sess_path}"')
        discriminator_config.update(discriminator_config_overwrite or {})
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

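    # The checkpoint mirrors the format the loader above expects: config,
    # weights, and optimizer state in a single dict.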
    def save_discriminator():
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save(
            {
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state': optimizer.state_dict()
            }, discriminator_sess_path)
        print('Done saving')

    # Disable gradient for generator
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    model.eval()
    discriminator.train()

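    # BCEWithLogitsLoss takes raw logits (hence output_logits=True below)
    # and fuses the sigmoid into the loss for numerical stability.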
    loss_func = nn.BCEWithLogitsLoss()
    last_save_time = time.time()

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break

            steps, batch_size = events.shape

            # Prepare inputs
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None

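            # A fresh latent vector for each batch seeds the generator below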
            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            real_target = torch.ones_like(real_logit)

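            # sample_factor is the softmax temperature passed to the
            # generator; randomizing it exposes the discriminator to both
            # conservative (low temperature) and noisy (high temperature)
            # fake sequences.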
            if auto_sample_factor:
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.4, 1.6, 2.0,
                    4.0, 10.0
                ])

            # Predict for fake event sequence from the generator
            fake_events = model.generate(init,
                                         steps,
                                         None,
                                         controls,
                                         greedy=0,
                                         output_type='index',
                                         temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit)

            # Discriminator loss: average of BCE on real (target 1) and
            # generated (target 0) sequences
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping
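            # When truthy, `gradient_clipping` doubles as the max-norm
            # threshold; the norm itself is always computed for logging.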
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(),
                                         gradient_clipping)

            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')
            if enable_logging:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()

    except KeyboardInterrupt:
        pass

    # Save on interrupt or normal completion so recent iterations are kept
    save_discriminator()
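

# ------------------------------------------------------------------------
# Usage sketch (not part of the original source). The Dataset class, its
# batches() signature, and the paths below are assumptions based on the
# "Dataset(...).batches(...)" hint in the parameter list above.
# ------------------------------------------------------------------------
if __name__ == '__main__':
    from data import Dataset  # assumed module providing the dataset loader

    dataset = Dataset('dataset/processed')  # placeholder dataset directory
    batches = dataset.batches(batch_size=64)  # assumed signature
    pretrain_discriminator(
        model_sess_path='save/model.sess',         # generator checkpoint to load
        discriminator_sess_path='save/disc.sess',  # created on first save
        batch_data_generator=batches,
        discriminator_lr=0.001,
        num_iter=10000,       # -1 runs until the data generator is exhausted
        save_interval=60.0,   # seconds between periodic saves
        enable_logging=True)  # requires tensorboardX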