print('Greedy ratio:', greedy_ratio) print('Beam size:', beam_size) print('Output directory:', output_dir) print('Controls:', control) print('Temperature:', temperature) print('Init zero:', init_zero) print('-' * 70) #======================================================================== # Generating #======================================================================== state = torch.load(sess_path) model = PerformanceRNN(**state['model_config']).to(device) model.load_state_dict(state['model_state']) model.eval() print(model) print('-' * 70) if init_zero: init = torch.zeros(batch_size, model.init_dim).to(device) else: init = torch.randn(batch_size, model.init_dim).to(device) with torch.no_grad(): if use_beam_search: outputs = model.beam_search(init, max_len, beam_size, controls=controls, temperature=temperature,
def pretrain_discriminator(
        model_sess_path,  # load
        discriminator_sess_path,  # load + save
        batch_data_generator,  # Dataset(...).batches(...)
        discriminator_config_overwrite=None,
        gradient_clipping=False,
        control_ratio=1.0,
        num_iter=-1,
        save_interval=60.0,
        discriminator_lr=0.001,
        enable_logging=False,
        auto_sample_factor=False,
        sample_factor=1.0):
    """Pretrain the event-sequence discriminator against a frozen generator.

    Loads a trained PerformanceRNN generator from ``model_sess_path``
    (read-only), then trains an EventSequenceEncoder discriminator with
    BCE-with-logits loss to separate real event batches (from
    ``batch_data_generator``) from sequences sampled from the generator.
    The discriminator session (config, weights, optimizer state) is
    periodically saved to ``discriminator_sess_path`` and also saved on
    normal completion or Ctrl-C.

    Args:
        model_sess_path: path to a saved generator session
            ({'model_config', 'model_state'}).
        discriminator_sess_path: path to load/save the discriminator
            session; a new session is started if it cannot be loaded.
        batch_data_generator: yields (events, controls) batches, where
            events is shaped (steps, batch) — presumably integer event
            indices; TODO confirm against Dataset.batches.
        discriminator_config_overwrite: dict merged into the module-level
            ``discriminator_config`` when starting a new session.
            Defaults to no overwrite (None avoids a shared mutable default).
        gradient_clipping: falsy to disable, else the max gradient norm.
        control_ratio: probability of feeding control signals to the
            generator for a given batch.
        num_iter: stop after this many iterations (-1 = run until the
            data generator is exhausted).
        save_interval: seconds between periodic saves.
        discriminator_lr: Adam learning rate for the discriminator.
        enable_logging: if True, write scalars via tensorboardX.
        auto_sample_factor: if True, redraw ``sample_factor`` (generator
            sampling temperature) randomly each iteration.
        sample_factor: generator sampling temperature when not automatic.
    """
    # Mutable-default fix: resolve None to an empty overwrite dict here so
    # the printed value and the update() call behave exactly as before.
    if discriminator_config_overwrite is None:
        discriminator_config_overwrite = {}

    print('-' * 70)

    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])
    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config

    # Best-effort load: any failure (missing file, corrupt checkpoint,
    # missing keys) starts a fresh session — but unlike a bare `except:`,
    # KeyboardInterrupt/SystemExit are not swallowed and the reason is shown.
    try:
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess[
            'discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    except Exception as load_error:
        print(f'New discriminator session at "{discriminator_sess_path}"')
        print('(load failed:', load_error, ')')
        discriminator_config.update(discriminator_config_overwrite)
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)

    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

    def save_discriminator():
        # Persist config + weights + optimizer state as one session dict.
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save(
            {
                'discriminator_config': discriminator_config,
                'discriminator_state': discriminator.state_dict(),
                'discriminator_optimizer_state': optimizer.state_dict()
            }, discriminator_sess_path)
        print('Done saving')

    # Disable gradient for generator — only the discriminator trains.
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    model.eval()
    discriminator.train()

    loss_func = nn.BCEWithLogitsLoss()

    last_save_time = time.time()

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break

            steps, batch_size = events.shape

            # Prepare inputs
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None

            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence (label 1).
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            # ones_like/zeros_like already inherit the logit's device.
            real_target = torch.ones_like(real_logit)

            if auto_sample_factor:
                # Randomize the sampling temperature to expose the
                # discriminator to a range of generator behaviors.
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.4, 1.6,
                    2.0, 4.0, 10.0
                ])

            # Predict for fake event sequence from the generator (label 0).
            fake_events = model.generate(init,
                                         steps,
                                         None,
                                         controls,
                                         greedy=0,
                                         output_type='index',
                                         temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit)

            # Compute loss: mean of the real and fake BCE terms.
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(),
                                         gradient_clipping)

            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')

            if enable_logging:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()

        # Save on normal completion too — previously only Ctrl-C saved,
        # losing all progress since the last periodic save.
        save_discriminator()

    except KeyboardInterrupt:
        save_discriminator()