def load_model_path():
    """Restore model/optimizer state from the checkpoint at `model_path`.

    Uses module globals: `model_path`, `model_config`, `device`,
    `learning_rate`, `reset_optimizer`. If the checkpoint carries its own
    'model_config', it overrides the current global config.

    Returns:
        (model, optimizer): a PerformanceRNN moved to `device` and an Adam
        optimizer, with states restored when a checkpoint was loadable.
    """
    global model_path, model_config, device, learning_rate, reset_optimizer
    try:
        param = torch.load(model_path)
        if 'model_config' in param and param['model_config'] != model_config:
            # Checkpoint config takes precedence over the current one.
            # (Original Chinese print: "use the model's settings instead".)
            model_config = param['model_config']
            print('用model的設置,不要用:')
            # BUG FIX: pass the config dict itself, not the string
            # 'model_config' (matches the sibling load_session()).
            print(utils.dict2params(model_config))
        model_state = param['model_state']
        optimizer_state = param['model_optimizer_state']
        print('參數來自:', model_path)
        param_loaded = True
    except Exception:
        # Narrowed from bare `except:` so Ctrl-C / SystemExit still propagate.
        print('無先前參數')
        param_loaded = False

    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if param_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
def load_session():
    """Restore training state from the session file at `sess_path`.

    Uses module globals: `sess_path`, `model_config`, `device`,
    `learning_rate`, `reset_optimizer`. If the session carries its own
    'model_config', it overrides the current global config.

    Returns:
        (model, optimizer): a PerformanceRNN moved to `device` and an Adam
        optimizer, with states restored when a session was loadable.
    """
    global sess_path, model_config, device, learning_rate, reset_optimizer
    try:
        sess = torch.load(sess_path)
        if 'model_config' in sess and sess['model_config'] != model_config:
            # The saved session's config takes precedence over the current one.
            model_config = sess['model_config']
            print('Use session config instead:')
            print(utils.dict2params(model_config))
        model_state = sess['model_state']
        optimizer_state = sess['model_optimizer_state']
        print('Session is loaded from', sess_path)
        sess_loaded = True
    except Exception:
        # Narrowed from bare `except:` so Ctrl-C / SystemExit still propagate.
        print('New session')
        sess_loaded = False

    model = PerformanceRNN(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    if sess_loaded:
        model.load_state_dict(model_state)
        if not reset_optimizer:
            optimizer.load_state_dict(optimizer_state)
    return model, optimizer
# Generation script chunk: print the run options, rebuild a trained generator
# from the `sess_path` checkpoint, then sample (beam search branch shown).
print('Batch size:', batch_size)
print('Max length:', max_len)
print('Greedy ratio:', greedy_ratio)
print('Beam size:', beam_size)
print('Output directory:', output_dir)
print('Controls:', control)
print('Temperature:', temperature)
print('Init zero:', init_zero)
print('-' * 70)

#========================================================================
# Generating
#========================================================================

# Rebuild the model from the checkpoint's own config and weights, then
# switch to eval mode for inference.
state = torch.load(sess_path)
model = PerformanceRNN(**state['model_config']).to(device)
model.load_state_dict(state['model_state'])
model.eval()
print(model)
print('-' * 70)

# Initial latent vector: all-zeros when `init_zero`, Gaussian noise otherwise.
if init_zero:
    init = torch.zeros(batch_size, model.init_dim).to(device)
else:
    init = torch.randn(batch_size, model.init_dim).to(device)

with torch.no_grad():
    if use_beam_search:
        # NOTE(review): this call is truncated in this source chunk — the
        # remaining beam_search arguments continue outside the visible view.
        outputs = model.beam_search(init, max_len, beam_size,
# Generation script chunk (plain autoregressive sampling, no beam search).
# NOTE(review): this chunk begins mid-`if` — the matching header (presumably
# `if control is not None:` or similar) lies outside the visible source;
# the indentation below is reconstructed accordingly.
    # Tile the control vector for the whole batch and keep its repr for the
    # option printout below (assumes control.to_array() yields a per-step
    # control vector — TODO confirm against the caller).
    controls = torch.tensor(control.to_array(), dtype=torch.float32)
    controls = controls.repeat(1, batch_size, 1).to(device)
    control = repr(control)
else:
    controls = None
    control = 'None'

print('-' * 50)
print('model_path = ', model_path)
print('control = ', control)
print('batch_size = ', batch_size)
print('temperature = ', temperature)
print('output_path = ', output_path)

# Rebuild the generator from the checkpoint's own config and weights, then
# switch to eval mode for inference.
state = torch.load(model_path)
model = PerformanceRNN(**state['model_config']).to(device)
model.load_state_dict(state['model_state'])
model.eval()
print(model)
print('-' * 50)

# Gaussian latent init, then sample the whole batch without gradients.
init = torch.randn(batch_size, model.init_dim).to(device)

with torch.no_grad():
    outputs = model.generate(init, max_len, controls=controls,
                             temperature=temperature, verbose=True)
    outputs = outputs.cpu().numpy().T  # [batch, steps]
def pretrain_discriminator(
        model_sess_path,            # load
        discriminator_sess_path,    # load + save
        batch_data_generator,       # Dataset(...).batches(...)
        discriminator_config_overwrite=None,
        gradient_clipping=False,
        control_ratio=1.0,
        num_iter=-1,
        save_interval=60.0,
        discriminator_lr=0.001,
        enable_logging=False,
        auto_sample_factor=False,
        sample_factor=1.0):
    """Pretrain the discriminator against a frozen generator.

    Loads a trained PerformanceRNN from `model_sess_path` (gradients
    disabled), then trains an EventSequenceEncoder discriminator to tell
    real event sequences from generated ones using BCE-with-logits loss.
    Periodically saves, and saves on KeyboardInterrupt.

    Args:
        model_sess_path: generator checkpoint to load.
        discriminator_sess_path: discriminator session file (load + save).
        batch_data_generator: yields (events, controls) batches.
        discriminator_config_overwrite: dict merged into the global
            discriminator config when starting a new session (default: {}).
        gradient_clipping: max grad norm, or falsy to disable.
        control_ratio: probability of conditioning on controls each step.
        num_iter: stop after this many iterations (-1 = run forever).
        save_interval: seconds between checkpoint saves.
        discriminator_lr: Adam learning rate.
        enable_logging: write tensorboardX scalars when True.
        auto_sample_factor: randomly pick a sampling temperature per step.
        sample_factor: fixed sampling temperature when not auto.
    """
    # FIX: mutable default argument `{}` replaced with None sentinel.
    if discriminator_config_overwrite is None:
        discriminator_config_overwrite = {}

    print('-' * 70)
    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])
    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config
    try:
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess[
            'discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    except Exception:
        # Narrowed from bare `except:` so Ctrl-C / SystemExit still propagate.
        print(f'New discriminator session at "{discriminator_sess_path}"')
        discriminator_config.update(discriminator_config_overwrite)
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

    def save_discriminator():
        # Persist config + weights + optimizer state in one file.
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save(
            {
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state': optimizer.state_dict()
            }, discriminator_sess_path)
        print('Done saving')

    # Disable gradient for generator — it is a fixed sampler here.
    for parameter in model.parameters():
        parameter.requires_grad_(False)
    model.eval()

    discriminator.train()
    loss_func = nn.BCEWithLogitsLoss()

    last_save_time = time.time()

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break
            steps, batch_size = events.shape

            # Prepare inputs: drop controls with probability 1 - control_ratio.
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None
            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence (target = 1).
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            real_target = torch.ones_like(real_logit).to(device)

            if auto_sample_factor:
                # Randomize the generator temperature so the discriminator
                # sees fakes of varying sharpness.
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.4, 1.6,
                    2.0, 4.0, 10.0
                ])

            # Predict for fake event sequence from the generator (target = 0).
            fake_events = model.generate(init,
                                         steps,
                                         None,
                                         controls,
                                         greedy=0,
                                         output_type='index',
                                         temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit).to(device)

            # Compute loss: average of real and fake BCE terms.
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping (norm recorded before clipping).
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(),
                                         gradient_clipping)
            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')
            if enable_logging:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()
    except KeyboardInterrupt:
        save_discriminator()
def train_adversarial(sess_path, batch_data_generator, model_load_path,
                      model_optimizer_class, model_gradient_clipping,
                      discriminator_gradient_clipping, model_learning_rate,
                      reset_model_optimizer, discriminator_load_path,
                      discriminator_optimizer_class,
                      discriminator_learning_rate,
                      reset_discriminator_optimizer, g_max_q_mean,
                      g_min_q_mean, d_min_loss, g_max_steps, d_max_steps,
                      mc_sample_size, mc_sample_factor, first_to_train,
                      save_interval, control_ratio, enable_logging):
    """Adversarial (SeqGAN-style) training of generator vs. discriminator.

    Resumes a combined session from `sess_path` when that file exists,
    otherwise loads the generator from `model_load_path` and the
    discriminator from `discriminator_load_path`. Alternates G-steps
    (policy-gradient updates using Monte-Carlo rollout Q-values from the
    discriminator) and D-steps (BCE real-vs-fake), switching on the
    q_mean / loss / step-count thresholds. Saves periodically and on
    KeyboardInterrupt (also raised deliberately when q_mean collapses).
    """
    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    if os.path.isfile(sess_path):
        # Resume: one file holds both networks and both optimizer states.
        adv_state = torch.load(sess_path)
        model_config = adv_state['model_config']
        model_state = adv_state['model_state']
        model_optimizer_state = adv_state['model_optimizer_state']
        discriminator_config = adv_state['discriminator_config']
        discriminator_state = adv_state['discriminator_state']
        discriminator_optimizer_state = adv_state[
            'discriminator_optimizer_state']
        print('-' * 70)
        print('Session is loaded from', sess_path)
        loaded_from_session = True
    else:
        # Fresh adversarial session from separately pretrained checkpoints.
        model_sess = torch.load(model_load_path)
        model_config = model_sess['model_config']
        model_state = model_sess['model_state']
        discriminator_sess = torch.load(discriminator_load_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        loaded_from_session = False

    model = PerformanceRNN(**model_config)
    model.load_state_dict(model_state)
    model.to(device).train()
    model_optimizer = model_optimizer_class(model.parameters(),
                                            lr=model_learning_rate)

    discriminator = EventSequenceEncoder(**discriminator_config)
    discriminator.load_state_dict(discriminator_state)
    discriminator.to(device).train()
    discriminator_optimizer = discriminator_optimizer_class(
        discriminator.parameters(), lr=discriminator_learning_rate)

    # Optimizer states only exist in a combined session file.
    if loaded_from_session:
        if not reset_model_optimizer:
            model_optimizer.load_state_dict(model_optimizer_state)
        if not reset_discriminator_optimizer:
            discriminator_optimizer.load_state_dict(
                discriminator_optimizer_state)

    # NOTE(review): `reduce=False` is deprecated in modern PyTorch
    # (use reduction='none'). Also, these two loss objects are only
    # printed below — the training steps call F.cross_entropy /
    # F.binary_cross_entropy_with_logits directly.
    g_loss_func = nn.CrossEntropyLoss()
    d_loss_func = nn.BCEWithLogitsLoss(reduce=False)

    print('-' * 70)
    print('Options')
    print('sess_path:', sess_path)
    print('save_interval:', save_interval)
    print('batch_data_generator:', batch_data_generator)
    print('control_ratio:', control_ratio)
    print('g_max_q_mean:', g_max_q_mean)
    print('g_min_q_mean:', g_min_q_mean)
    print('d_min_loss:', d_min_loss)
    print('mc_sample_size:', mc_sample_size)
    print('mc_sample_factor:', mc_sample_factor)
    print('enable_logging:', enable_logging)
    print('model_load_path:', model_load_path)
    print('model_loss:', g_loss_func)
    print('model_optimizer_class:', model_optimizer_class)
    print('model_gradient_clipping:', model_gradient_clipping)
    print('model_learning_rate:', model_learning_rate)
    print('reset_model_optimizer:', reset_model_optimizer)
    print('discriminator_load_path:', discriminator_load_path)
    print('discriminator_loss:', d_loss_func)
    print('discriminator_optimizer_class:', discriminator_optimizer_class)
    print('discriminator_gradient_clipping:',
          discriminator_gradient_clipping)
    print('discriminator_learning_rate:', discriminator_learning_rate)
    print('reset_discriminator_optimizer:', reset_discriminator_optimizer)
    print('first_to_train:', first_to_train)
    print('-' * 70)
    print(
        f'Generator from "{sess_path if loaded_from_session else model_load_path}"'
    )
    print(model)
    print(model_optimizer)
    print('-' * 70)
    print(
        f'Discriminator from "{sess_path if loaded_from_session else discriminator_load_path}"'
    )
    print(discriminator)
    print(discriminator_optimizer)
    print('-' * 70)

    def save():
        # Persist both networks and both optimizer states in one file.
        print(f'Saving to "{sess_path}"')
        torch.save(
            {
                'model_config': model_config,
                'model_state': model.state_dict(),
                'model_optimizer_state': model_optimizer.state_dict(),
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state':
                    discriminator_optimizer.state_dict()
            }, sess_path)
        print('Done saving')

    def mc_rollout(generated, hidden, total_steps, controls=None):
        """Complete each partial sequence `mc_sample_size` times by sampling
        the remaining steps from the (frozen-for-this-call) generator.

        Returns the full sequences, shape [total_steps, mc_sample_size * b].
        """
        # generated: [t, batch_size]
        # hidden: [n_layers, batch_size, hidden_dim]
        # controls: [total_steps - t, batch_size, control_dim]
        generated = torch.cat(generated, 0)
        generated_steps, batch_size = generated.shape  # t, b
        steps = total_steps - generated_steps  # s

        # Replicate the prefix once per Monte-Carlo sample.
        generated = generated.unsqueeze(1)  # [t, 1, b]
        generated = generated.repeat(1, mc_sample_size, 1)  # [t, mcs, b]
        generated = generated.view(generated_steps, -1)  # [t, mcs * b]

        # Replicate the RNN hidden state to match.
        hidden = hidden.unsqueeze(1).repeat(1, mc_sample_size, 1, 1)
        hidden = hidden.view(model.gru_layers, -1, model.hidden_dim)

        if controls is not None:
            assert controls.shape == (steps, batch_size, model.control_dim)
            controls = controls.unsqueeze(1)  # [s, 1, b, c]
            controls = controls.repeat(1, mc_sample_size, 1,
                                       1)  # [s, mcs, b, c]
            controls = controls.view(steps, -1,
                                     model.control_dim)  # [s, mcs * b, c]

        event = generated[-1].unsqueeze(0)  # [1, mcs * b]
        control = None  # default when controls is None

        # Sample the remaining `steps` events autoregressively.
        outputs = []
        for i in range(steps):
            if controls is not None:
                control = controls[i].unsqueeze(0)  # [1, mcs * b, c]
            output, hidden = model.forward(event,
                                           control=control,
                                           hidden=hidden)
            probs = model.output_fc_activation(output / mc_sample_factor)
            event = Categorical(probs).sample()  # [1, mcs * b]
            outputs.append(event)

        sequences = torch.cat([generated, *outputs], 0)
        assert sequences.shape == (total_steps, mc_sample_size * batch_size)
        return sequences

    def train_generator(batch_size, init, events, controls):
        """One policy-gradient G-step; returns (q_mean, grad_norm).

        Uses `steps` and `use_control` from the enclosing loop's scope.
        """
        # Generator step
        hidden = model.init_to_hidden(init)
        event = model.get_primary_event(batch_size)
        outputs = []
        generated = []
        q_values = []

        # NOTE(review): `event` is never reassigned to the sampled token
        # inside this loop (unlike mc_rollout above) — every forward call
        # is fed the primary event. Looks like a bug; confirm intent.
        # NOTE(review): `Bar` is an external progress-bar helper —
        # presumably from the `progress` package; not visible here.
        for step in Bar('MC Rollout').iter(range(steps)):
            control = controls[step].unsqueeze(0) if use_control else None
            output, hidden = model.forward(event,
                                           control=control,
                                           hidden=hidden)
            outputs.append(output)
            probs = model.output_fc_activation(output / mc_sample_factor)
            generated.append(Categorical(probs).sample())

            with torch.no_grad():
                if step < steps - 1:
                    # Q(step) = mean discriminator score over MC completions.
                    # NOTE(review): `controls[step + 1:]` will raise when
                    # use_control is False (controls is None) — verify.
                    sequences = mc_rollout(generated, hidden, steps,
                                           controls[step + 1:])
                    mc_score = discriminator(sequences)  # [mcs * b]
                    mc_score = mc_score.view(mc_sample_size,
                                             batch_size)  # [mcs, b]
                    q_value = mc_score.mean(0,
                                            keepdim=True)  # [1, batch_size]
                else:
                    # Final step: the sequence is complete; score it directly.
                    q_value = discriminator(torch.cat(generated, 0))
                    q_value = q_value.unsqueeze(0)  # [1, batch_size]
                q_values.append(q_value)

        # Compute loss: REINFORCE-style, per-token CE weighted by
        # mean-centered Q-values.
        q_values = torch.cat(q_values, 0)  # [steps, batch_size]
        q_mean = q_values.mean().detach()
        q_values = q_values - q_mean
        generated = torch.cat(generated, 0)  # [steps, batch_size]
        outputs = torch.cat(outputs, 0)  # [steps, batch_size, event_dim]
        # NOTE(review): `reduce=False` is deprecated (use reduction='none').
        loss = F.cross_entropy(outputs.view(-1, model.event_dim),
                               generated.view(-1),
                               reduce=False)
        loss = (loss * q_values.view(-1)).mean()

        # Backprop
        model.zero_grad()
        loss.backward()

        # Gradient clipping (norm recorded before clipping).
        norm = utils.compute_gradient_norm(model.parameters())
        if model_gradient_clipping:
            nn.utils.clip_grad_norm_(model.parameters(),
                                     model_gradient_clipping)
        model_optimizer.step()

        q_mean = q_mean.item()
        norm = norm.item()
        return q_mean, norm

    def train_discriminator(batch_size, init, events, controls):
        """One BCE D-step; returns (loss, real_loss, fake_loss, grad_norm).

        Uses `steps` from the enclosing loop's scope.
        """
        # Discriminator step: sample fakes with no generator gradients.
        with torch.no_grad():
            generated = model.generate(init,
                                       steps,
                                       None,
                                       controls,
                                       greedy=0,
                                       temperature=mc_sample_factor)

        fake_logit = discriminator(generated, output_logits=True)
        real_logit = discriminator(events, output_logits=True)
        fake_target = torch.zeros_like(fake_logit)
        real_target = torch.ones_like(real_logit)

        # Compute loss: average of real and fake BCE terms.
        fake_loss = F.binary_cross_entropy_with_logits(fake_logit,
                                                       fake_target)
        real_loss = F.binary_cross_entropy_with_logits(real_logit,
                                                       real_target)
        loss = (real_loss + fake_loss) / 2

        # Backprop
        discriminator.zero_grad()
        loss.backward()

        # Gradient clipping (norm recorded before clipping).
        norm = utils.compute_gradient_norm(discriminator.parameters())
        if discriminator_gradient_clipping:
            nn.utils.clip_grad_norm_(discriminator.parameters(),
                                     discriminator_gradient_clipping)
        discriminator_optimizer.step()

        real_loss = real_loss.item()
        fake_loss = fake_loss.item()
        loss = loss.item()
        norm = norm.item()
        return loss, real_loss, fake_loss, norm

    try:
        last_save_time = time.time()
        step_for = first_to_train  # 'G' or 'D': which network trains now
        g_steps = 0
        d_steps = 0

        for i, (events, controls) in enumerate(batch_data_generator):
            steps, batch_size = events.shape
            init = torch.randn(batch_size, model.init_dim).to(device)
            events = torch.LongTensor(events).to(device)
            # Drop controls with probability 1 - control_ratio.
            use_control = np.random.random() <= control_ratio
            controls = torch.FloatTensor(controls).to(
                device) if use_control else None

            if step_for == 'G':
                q_mean, norm = train_generator(batch_size, init, events,
                                               controls)
                g_steps += 1
                print(f'{i} (G-step) Q_mean: {q_mean}, norm: {norm}')
                if enable_logging:
                    writer.add_scalar('adversarial/G/Q_mean', q_mean, i)
                    writer.add_scalar('adversarial/G/norm', norm, i)

                if q_mean < g_min_q_mean:
                    # Q collapsed: abort via the KeyboardInterrupt save path.
                    print(f'Q is too small: {q_mean}, exiting')
                    raise KeyboardInterrupt

                # Hand over to D when G is doing well enough or hit its cap.
                if q_mean > g_max_q_mean or (g_max_steps
                                             and g_steps >= g_max_steps):
                    step_for = 'D'
                    d_steps = 0

            if step_for == 'D':
                loss, real_loss, fake_loss, norm = train_discriminator(
                    batch_size, init, events, controls)
                d_steps += 1
                print(
                    f'{i} (D-step) loss: {loss} (real: {real_loss}, fake: {fake_loss}), norm: {norm}'
                )
                if enable_logging:
                    writer.add_scalar('adversarial/D/loss', loss, i)
                    writer.add_scalar('adversarial/D/norm', norm, i)

                # Hand over to G when D is strong enough or hit its cap.
                if fake_loss <= real_loss < d_min_loss or (
                        d_max_steps and d_steps >= d_max_steps):
                    step_for = 'G'
                    g_steps = 0

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save()
    except KeyboardInterrupt:
        save()