Example #1
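All of the snippets below come from the same PerformanceRNN training and generation codebase and omit their import header. A plausible reconstruction, inferred from the names the examples use (the utils module and the model module path are assumptions):

import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from progress.bar import Bar  # progress-bar helper used in Example #6

import utils                                            # project-local helpers (assumed)
from model import PerformanceRNN, EventSequenceEncoder  # assumed module path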
def load_model_path():
    global model_path, model_config, device, learning_rate, reset_optimizer

    try:
        param = torch.load(model_path)
        if 'model_config' in param and param['model_config'] != model_config:
            model_config = param['model_config']
            print('Use model config instead:')
            print(utils.dict2params(model_config))
        model_state = param['model_state']
        optimizer_state = param['model_optimizer_state']
        print('Parameters loaded from', model_path)
        param_loaded = True

    except Exception:
        print('No previous parameters')
        param_loaded = False

    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if param_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)

    return model, optimizer
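A minimal sketch of driving this loader; every value below is a placeholder, and the model_config keys are guessed from the attributes the later examples read off the model (init_dim, event_dim, control_dim, hidden_dim, gru_layers):

# Hypothetical globals that load_model_path() reads and may overwrite.
model_path = 'save/model.sess'
model_config = dict(init_dim=32, event_dim=240, control_dim=24,
                    hidden_dim=512, gru_layers=3)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
learning_rate = 0.001
reset_optimizer = False

model, optimizer = load_model_path()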
Example #2
def load_session():
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        sess = torch.load(sess_path)
        if 'model_config' in sess and sess['model_config'] != model_config:
            model_config = sess['model_config']
            print('Use session config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except Exception:
        print('New session')
        sess_loaded = False
    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
Example #3
print('Batch size:', batch_size)
print('Max length:', max_len)
print('Greedy ratio:', greedy_ratio)
print('Beam size:', beam_size)
print('Output directory:', output_dir)
print('Controls:', control)
print('Temperature:', temperature)
print('Init zero:', init_zero)
print('-' * 70)

#========================================================================
# Generating
#========================================================================

state = torch.load(sess_path)
model = PerformanceRNN(**state['model_config']).to(device)
model.load_state_dict(state['model_state'])
model.eval()
print(model)
print('-' * 70)

if init_zero:
    init = torch.zeros(batch_size, model.init_dim).to(device)
else:
    init = torch.randn(batch_size, model.init_dim).to(device)

with torch.no_grad():
    if use_beam_search:
        outputs = model.beam_search(init,
                                    max_len,
                                    beam_size,
                                    controls=controls,
                                    temperature=temperature,
                                    verbose=True)
    else:
        outputs = model.generate(init,
                                 max_len,
                                 controls=controls,
                                 greedy=greedy_ratio,
                                 temperature=temperature,
                                 verbose=True)
Example #4
if control is not None:
    controls = torch.tensor(control.to_array(), dtype=torch.float32)
    controls = controls.repeat(1, batch_size, 1).to(device)
    control = repr(control)
else:
    controls = None
    control = 'None'

print('-' * 50)
print('model_path = ', model_path)
print('control = ', control)
print('batch_size = ', batch_size)
print('temperature = ', temperature)
print('output_path = ', output_path)

state = torch.load(model_path)
model = PerformanceRNN(**state['model_config']).to(device)
model.load_state_dict(state['model_state'])
model.eval()
print(model)
print('-' * 50)

init = torch.randn(batch_size, model.init_dim).to(device)

with torch.no_grad():
    outputs = model.generate(init,
                             max_len,
                             controls=controls,
                             temperature=temperature,
                             verbose=True)

outputs = outputs.cpu().numpy().T  # [batch, steps]
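The excerpt stops once the generated batch is back on the CPU. A minimal follow-up that persists the raw event indices (the original project converts them to MIDI with its event-sequence utilities, not shown here); the per-item filename suffix is an invented convention:

root, ext = os.path.splitext(output_path)
for k, indices in enumerate(outputs):  # outputs: [batch, steps]
    np.save(f'{root}-{k:03d}.npy', indices)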
Example #5
def pretrain_discriminator(
        model_sess_path,  # load
        discriminator_sess_path,  # load + save
        batch_data_generator,  # Dataset(...).batches(...)
        discriminator_config_overwrite={},
        gradient_clipping=False,
        control_ratio=1.0,
        num_iter=-1,
        save_interval=60.0,
        discriminator_lr=0.001,
        enable_logging=False,
        auto_sample_factor=False,
        sample_factor=1.0):

    print('-' * 70)
    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])

    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config
    try:
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess[
            'discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    except Exception:
        print(f'New discriminator session at "{discriminator_sess_path}"')
        discriminator_config.update(discriminator_config_overwrite)
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

    def save_discriminator():
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save(
            {
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state': optimizer.state_dict()
            }, discriminator_sess_path)
        print('Done saving')

    # Disable gradient for generator
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    model.eval()
    discriminator.train()

    loss_func = nn.BCEWithLogitsLoss()
    last_save_time = time.time()

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break

            steps, batch_size = events.shape

            # Prepare inputs
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None

            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            real_target = torch.ones_like(real_logit).to(device)

            if auto_sample_factor:
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.4, 1.6, 2.0,
                    4.0, 10.0
                ])

            # Predict for fake event sequence from the generator
            fake_events = model.generate(init,
                                         steps,
                                         None,
                                         controls,
                                         greedy=0,
                                         output_type='index',
                                         temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit).to(device)

            # Compute loss
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(),
                                         gradient_clipping)

            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')
            if enable_logging:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()

    except KeyboardInterrupt:
        save_discriminator()
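A hypothetical invocation; Dataset and its batches(...) signature are placeholders suggested only by the inline comment on batch_data_generator:

dataset = Dataset('dataset/processed')  # hypothetical batch source
pretrain_discriminator(
    model_sess_path='save/model.sess',
    discriminator_sess_path='save/discriminator.sess',
    batch_data_generator=dataset.batches(batch_size=64),  # placeholder args
    num_iter=10000,
    save_interval=120.0,
    enable_logging=True)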
Example #6
def train_adversarial(sess_path, batch_data_generator, model_load_path,
                      model_optimizer_class, model_gradient_clipping,
                      discriminator_gradient_clipping, model_learning_rate,
                      reset_model_optimizer, discriminator_load_path,
                      discriminator_optimizer_class,
                      discriminator_learning_rate,
                      reset_discriminator_optimizer, g_max_q_mean,
                      g_min_q_mean, d_min_loss, g_max_steps, d_max_steps,
                      mc_sample_size, mc_sample_factor, first_to_train,
                      save_interval, control_ratio, enable_logging):

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    if os.path.isfile(sess_path):
        adv_state = torch.load(sess_path)
        model_config = adv_state['model_config']
        model_state = adv_state['model_state']
        model_optimizer_state = adv_state['model_optimizer_state']
        discriminator_config = adv_state['discriminator_config']
        discriminator_state = adv_state['discriminator_state']
        discriminator_optimizer_state = adv_state[
            'discriminator_optimizer_state']
        print('-' * 70)
        print('Session is loaded from', sess_path)
        loaded_from_session = True

    else:
        model_sess = torch.load(model_load_path)
        model_config = model_sess['model_config']
        model_state = model_sess['model_state']
        discriminator_sess = torch.load(discriminator_load_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        loaded_from_session = False

    model = PerformanceRNN(**model_config)
    model.load_state_dict(model_state)
    model.to(device).train()
    model_optimizer = model_optimizer_class(model.parameters(),
                                            lr=model_learning_rate)

    discriminator = EventSequenceEncoder(**discriminator_config)
    discriminator.load_state_dict(discriminator_state)
    discriminator.to(device).train()
    discriminator_optimizer = discriminator_optimizer_class(
        discriminator.parameters(), lr=discriminator_learning_rate)

    if loaded_from_session:
        if not reset_model_optimizer:
            model_optimizer.load_state_dict(model_optimizer_state)
        if not reset_discriminator_optimizer:
            discriminator_optimizer.load_state_dict(
                discriminator_optimizer_state)

    g_loss_func = nn.CrossEntropyLoss()
    d_loss_func = nn.BCEWithLogitsLoss(reduction='none')

    print('-' * 70)
    print('Options')
    print('sess_path:', sess_path)
    print('save_interval:', save_interval)
    print('batch_data_generator:', batch_data_generator)
    print('control_ratio:', control_ratio)
    print('g_max_q_mean:', g_max_q_mean)
    print('g_min_q_mean:', g_min_q_mean)
    print('d_min_loss:', d_min_loss)
    print('mc_sample_size:', mc_sample_size)
    print('mc_sample_factor:', mc_sample_factor)
    print('enable_logging:', enable_logging)
    print('model_load_path:', model_load_path)
    print('model_loss:', g_loss_func)
    print('model_optimizer_class:', model_optimizer_class)
    print('model_gradient_clipping:', model_gradient_clipping)
    print('model_learning_rate:', model_learning_rate)
    print('reset_model_optimizer:', reset_model_optimizer)
    print('discriminator_load_path:', discriminator_load_path)
    print('discriminator_loss:', d_loss_func)
    print('discriminator_optimizer_class:', discriminator_optimizer_class)
    print('discriminator_gradient_clipping:', discriminator_gradient_clipping)
    print('discriminator_learning_rate:', discriminator_learning_rate)
    print('reset_discriminator_optimizer:', reset_discriminator_optimizer)
    print('first_to_train:', first_to_train)
    print('-' * 70)
    print(
        f'Generator from "{sess_path if loaded_from_session else model_load_path}"'
    )
    print(model)
    print(model_optimizer)
    print('-' * 70)
    print(
        f'Discriminator from "{sess_path if loaded_from_session else discriminator_load_path}"'
    )
    print(discriminator)
    print(discriminator_optimizer)
    print('-' * 70)

    def save():
        print(f'Saving to "{sess_path}"')
        torch.save(
            {
                'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': model_optimizer.state_dict(),
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state':
                    discriminator_optimizer.state_dict()
            }, sess_path)
        print('Done saving')

    def mc_rollout(generated, hidden, total_steps, controls=None):
        # generated: [t, batch_size]
        # hidden: [n_layers, batch_size, hidden_dim]
        # controls: [total_steps - t, batch_size, control_dim]
        generated = torch.cat(generated, 0)
        generated_steps, batch_size = generated.shape  # t, b
        steps = total_steps - generated_steps  # s

        generated = generated.unsqueeze(1)  # [t, 1, b]
        generated = generated.repeat(1, mc_sample_size, 1)  # [t, mcs, b]
        generated = generated.view(generated_steps, -1)  # [t, mcs * b]

        hidden = hidden.unsqueeze(1).repeat(1, mc_sample_size, 1, 1)
        hidden = hidden.view(model.gru_layers, -1, model.hidden_dim)

        if controls is not None:
            assert controls.shape == (steps, batch_size, model.control_dim)
            controls = controls.unsqueeze(1)  # [s, 1, b, c]
            controls = controls.repeat(1, mc_sample_size, 1,
                                       1)  # [s, mcs, b, c]
            controls = controls.view(steps, -1,
                                     model.control_dim)  # [s, mcs * b, c]

        event = generated[-1].unsqueeze(0)  # [1, mcs * b]
        control = None  # default when controls is None
        outputs = []

        for i in range(steps):
            if controls is not None:
                control = controls[i].unsqueeze(0)  # [1, mcs * b, c]

            output, hidden = model.forward(event,
                                           control=control,
                                           hidden=hidden)
            probs = model.output_fc_activation(output / mc_sample_factor)
            event = Categorical(probs).sample()  # [1, mcs * b]
            outputs.append(event)

        sequences = torch.cat([generated, *outputs], 0)
        assert sequences.shape == (total_steps, mc_sample_size * batch_size)
        return sequences
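mc_rollout freezes the partial sequence, tiles it mc_sample_size times along a new axis, and lets the frozen generator finish every copy so the discriminator can score complete sequences. A standalone check of the tiling arithmetic, with arbitrary dimensions:

import torch

t, b, mcs = 5, 4, 8                 # generated steps, batch size, MC samples
generated = torch.zeros(t, b, dtype=torch.long)
tiled = generated.unsqueeze(1).repeat(1, mcs, 1).view(t, -1)
assert tiled.shape == (t, mcs * b)  # [t, mcs * b], matching the comments above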

    def train_generator(batch_size, init, events, controls):
        # Generator step
        hidden = model.init_to_hidden(init)
        event = model.get_primary_event(batch_size)
        outputs = []
        generated = []
        q_values = []

        for step in Bar('MC Rollout').iter(range(steps)):
            control = controls[step].unsqueeze(0) if use_control else None
            output, hidden = model.forward(event,
                                           control=control,
                                           hidden=hidden)
            outputs.append(output)
            probs = model.output_fc_activation(output / mc_sample_factor)
            generated.append(Categorical(probs).sample())

            with torch.no_grad():
                if step < steps - 1:
                    sequences = mc_rollout(
                        generated, hidden, steps,
                        controls[step + 1:] if use_control else None)
                    mc_score = discriminator(sequences)  # [mcs * b]
                    mc_score = mc_score.view(mc_sample_size,
                                             batch_size)  # [mcs, b]
                    q_value = mc_score.mean(0, keepdim=True)  # [1, batch_size]

                else:
                    q_value = discriminator(torch.cat(generated, 0))
                    q_value = q_value.unsqueeze(0)  # [1, batch_size]

                q_values.append(q_value)

        # Compute loss
        q_values = torch.cat(q_values, 0)  # [steps, batch_size]
        q_mean = q_values.mean().detach()
        q_values = q_values - q_mean
        generated = torch.cat(generated, 0)  # [steps, batch_size]
        outputs = torch.cat(outputs, 0)  # [steps, batch_size, event_dim]
        loss = F.cross_entropy(outputs.view(-1, model.event_dim),
                               generated.view(-1),
                               reduction='none')
        loss = (loss * q_values.view(-1)).mean()

        # Backprop
        model.zero_grad()
        loss.backward()

        # Gradient clipping
        norm = utils.compute_gradient_norm(model.parameters())
        if model_gradient_clipping:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     model_gradient_clipping)

        model_optimizer.step()

        q_mean = q_mean.item()
        norm = norm.item()
        return q_mean, norm
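The tail of train_generator is a REINFORCE-style policy gradient with a mean baseline: each sampled event's negative log-likelihood is weighted by its centered Q-value, so descending this loss raises the probability of events whose Monte-Carlo score beats the batch average. Schematically, over T steps and batch size B:

    loss_G = mean over t, b of (Q[t, b] - Q_mean) * (-log p(x[t, b] | x[<t, b]))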

    def train_discriminator(batch_size, init, events, controls):
        # Discriminator step
        with torch.no_grad():
            generated = model.generate(init,
                                       steps,
                                       None,
                                       controls,
                                       greedy=0,
                                       temperature=mc_sample_factor)

        fake_logit = discriminator(generated, output_logits=True)
        real_logit = discriminator(events, output_logits=True)
        fake_target = torch.zeros_like(fake_logit)
        real_target = torch.ones_like(real_logit)

        # Compute loss
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit, fake_target)
        real_loss = F.binary_cross_entropy_with_logits(real_logit, real_target)
        loss = (real_loss + fake_loss) / 2

        # Backprop
        discriminator.zero_grad()
        loss.backward()

        # Gradient clipping
        norm = utils.compute_gradient_norm(discriminator.parameters())
        if discriminator_gradient_clipping:
            nn.utils.clip_grad_norm_(discriminator.parameters(),
                                     discriminator_gradient_clipping)

        discriminator_optimizer.step()

        real_loss = real_loss.item()
        fake_loss = fake_loss.item()
        loss = loss.item()
        norm = norm.item()
        return loss, real_loss, fake_loss, norm

    try:
        last_save_time = time.time()
        step_for = first_to_train
        g_steps = 0
        d_steps = 0

        for i, (events, controls) in enumerate(batch_data_generator):
            steps, batch_size = events.shape
            init = torch.randn(batch_size, model.init_dim).to(device)
            events = torch.LongTensor(events).to(device)

            use_control = np.random.random() <= control_ratio
            controls = torch.FloatTensor(controls).to(
                device) if use_control else None

            if step_for == 'G':
                q_mean, norm = train_generator(batch_size, init, events,
                                               controls)
                g_steps += 1

                print(f'{i} (G-step) Q_mean: {q_mean}, norm: {norm}')
                if enable_logging:
                    writer.add_scalar('adversarial/G/Q_mean', q_mean, i)
                    writer.add_scalar('adversarial/G/norm', norm, i)

                if q_mean < g_min_q_mean:
                    print(f'Q is too small: {q_mean}, exiting')
                    raise KeyboardInterrupt

                if q_mean > g_max_q_mean or (g_max_steps
                                             and g_steps >= g_max_steps):
                    step_for = 'D'
                    d_steps = 0

            if step_for == 'D':
                loss, real_loss, fake_loss, norm = train_discriminator(
                    batch_size, init, events, controls)
                d_steps += 1

                print(
                    f'{i} (D-step) loss: {loss} (real: {real_loss}, fake: {fake_loss}), norm: {norm}'
                )
                if enable_logging:
                    writer.add_scalar('adversarial/D/loss', loss, i)
                    writer.add_scalar('adversarial/D/norm', norm, i)

                if fake_loss <= real_loss < d_min_loss or (
                        d_max_steps and d_steps >= d_max_steps):
                    step_for = 'G'
                    g_steps = 0

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save()

    except KeyboardInterrupt:
        save()
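Every argument of train_adversarial is required, so a call sketch may help; all values below are placeholders, and dataset is the same hypothetical batch source as in the pretraining example:

train_adversarial(
    sess_path='save/adversarial.sess',
    batch_data_generator=dataset.batches(batch_size=64),  # placeholder
    model_load_path='save/model.sess',
    model_optimizer_class=optim.Adam,
    model_gradient_clipping=1.0,
    discriminator_gradient_clipping=1.0,
    model_learning_rate=1e-4,
    reset_model_optimizer=False,
    discriminator_load_path='save/discriminator.sess',
    discriminator_optimizer_class=optim.Adam,
    discriminator_learning_rate=1e-4,
    reset_discriminator_optimizer=False,
    g_max_q_mean=0.6,
    g_min_q_mean=0.05,
    d_min_loss=0.3,
    g_max_steps=10,
    d_max_steps=10,
    mc_sample_size=8,
    mc_sample_factor=1.0,
    first_to_train='G',
    save_interval=300.0,
    control_ratio=1.0,
    enable_logging=False)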