# Module-level imports assumed by this training script; the local module names
# (utils, model) are assumptions based on how they are used below.
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
from progress.bar import Bar

import utils
from model import PerformanceRNN, EventSequenceEncoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The GAN training steps below also rely on module-level state set up by the
# surrounding script: model, discriminator, their optimizers, steps,
# use_control, mc_sample_size, mc_sample_factor, the gradient-clipping
# settings, and a default discriminator_config dict.


def train_generator(batch_size, init, events, controls):
    # Generator step (policy gradient with Monte-Carlo rollouts)
    hidden = model.init_to_hidden(init)
    event = model.get_primary_event(batch_size)
    outputs = []
    generated = []
    q_values = []

    for step in Bar('MC Rollout').iter(range(steps)):
        control = controls[step].unsqueeze(0) if use_control else None
        output, hidden = model.forward(event, control=control, hidden=hidden)
        outputs.append(output)
        probs = model.output_fc_activation(output / mc_sample_factor)
        event = Categorical(probs).sample()  # feed the sample back as the next input
        generated.append(event)

        with torch.no_grad():
            if step < steps - 1:
                # Complete the sequence mc_sample_size times and let the
                # discriminator score the full-length rollouts
                sequences = mc_rollout(generated, hidden, steps,
                                       controls[step + 1:])
                mc_score = discriminator(sequences)  # [mcs * b]
                mc_score = mc_score.view(mc_sample_size, batch_size)  # [mcs, b]
                q_value = mc_score.mean(0, keepdim=True)  # [1, batch_size]
            else:
                q_value = discriminator(torch.cat(generated, 0))
                q_value = q_value.unsqueeze(0)  # [1, batch_size]

            q_values.append(q_value)

    # Compute loss: REINFORCE with the mean Q-value as baseline
    q_values = torch.cat(q_values, 0)  # [steps, batch_size]
    q_mean = q_values.mean().detach()
    q_values = q_values - q_mean
    generated = torch.cat(generated, 0)  # [steps, batch_size]
    outputs = torch.cat(outputs, 0)  # [steps, batch_size, event_dim]
    loss = F.cross_entropy(outputs.view(-1, model.event_dim),
                           generated.view(-1), reduction='none')
    loss = (loss * q_values.view(-1)).mean()

    # Backprop
    model.zero_grad()
    loss.backward()

    # Gradient clipping
    norm = utils.compute_gradient_norm(model.parameters())
    if model_gradient_clipping:
        nn.utils.clip_grad_norm_(model.parameters(), model_gradient_clipping)

    model_optimizer.step()

    q_mean = q_mean.item()
    norm = norm.item()
    return q_mean, norm
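
# train_generator() above relies on an mc_rollout() helper that is not shown
# in this file. The sketch below illustrates one way it could work, assuming
# the hidden state is a single GRU tensor of shape [layers, batch, hidden] and
# that the generator/discriminator interfaces are exactly the ones used above;
# the signature and shapes are assumptions, not the repository's actual
# implementation.
def mc_rollout(generated, hidden, steps, future_controls):
    prefix = torch.cat(generated, 0)              # [t, batch]
    t = prefix.shape[0]

    # Tile the prefix and hidden state once per Monte-Carlo sample
    prefix = prefix.repeat(1, mc_sample_size)     # [t, mcs * batch]
    hidden = hidden.repeat(1, mc_sample_size, 1)  # [layers, mcs * batch, hidden]
    event = prefix[-1].unsqueeze(0)               # [1, mcs * batch]

    completed = [prefix]
    with torch.no_grad():
        for step in range(steps - t):
            if use_control:
                control = future_controls[step].unsqueeze(0)
                control = control.repeat(1, mc_sample_size, 1)
            else:
                control = None
            output, hidden = model.forward(event, control=control, hidden=hidden)
            probs = model.output_fc_activation(output / mc_sample_factor)
            event = Categorical(probs).sample()   # [1, mcs * batch]
            completed.append(event)

    return torch.cat(completed, 0)                # [steps, mcs * batch]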
def train_discriminator(batch_size, init, events, controls):
    # Discriminator step
    with torch.no_grad():
        generated = model.generate(init, steps, None, controls,
                                   greedy=0, temperature=mc_sample_factor)

    fake_logit = discriminator(generated, output_logits=True)
    real_logit = discriminator(events, output_logits=True)
    fake_target = torch.zeros_like(fake_logit)
    real_target = torch.ones_like(real_logit)

    # Compute loss
    fake_loss = F.binary_cross_entropy_with_logits(fake_logit, fake_target)
    real_loss = F.binary_cross_entropy_with_logits(real_logit, real_target)
    loss = (real_loss + fake_loss) / 2

    # Backprop
    discriminator.zero_grad()
    loss.backward()

    # Gradient clipping
    norm = utils.compute_gradient_norm(discriminator.parameters())
    if discriminator_gradient_clipping:
        nn.utils.clip_grad_norm_(discriminator.parameters(),
                                 discriminator_gradient_clipping)

    discriminator_optimizer.step()

    real_loss = real_loss.item()
    fake_loss = fake_loss.item()
    loss = loss.item()
    norm = norm.item()
    return loss, real_loss, fake_loss, norm
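
# Both steps above log the pre-clipping gradient norm via
# utils.compute_gradient_norm(). The utils module is not shown here; a minimal
# sketch of such a helper (returning a 0-dim tensor so that .item() works):
def compute_gradient_norm(parameters, norm_type=2):
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norms = torch.stack([torch.norm(g, norm_type) for g in grads])
    return torch.norm(norms, norm_type)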
            # Teacher-forced (MLE) training step inside the main training loop
            controls = None

            init = torch.randn(batch_size, model.init_dim).to(device)
            outputs = model.generate(init, window_size, events=events[:-1],
                                     controls=controls,
                                     teacher_forcing_ratio=teacher_forcing_ratio,
                                     output_type='logit')
            assert outputs.shape[:2] == events.shape[:2]

            loss = loss_function(outputs.view(-1, event_dim), events.view(-1))
            model.zero_grad()
            loss.backward()

            norm = utils.compute_gradient_norm(model.parameters())
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()

            if enable_logging:
                writer.add_scalar('model/loss', loss.item(), iteration)
                writer.add_scalar('model/norm', norm.item(), iteration)

            print(f'iter {iteration}, loss: {loss.item()}')

            if time.time() - last_saving_time > saving_interval:
                save_model()
                last_saving_time = time.time()

    except KeyboardInterrupt:
        save_model()
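
# save_model() is defined elsewhere in the training script. A minimal sketch,
# mirroring save_discriminator() below and the checkpoint keys that
# pretrain_discriminator() reads ('model_config', 'model_state'); the optimizer
# key and the sess-path variable are assumptions:
def save_model():
    print(f'Saving to "{model_sess_path}"')
    torch.save({
        'model_config': model_config,
        'model_state': model.state_dict(),
        'model_optimizer_state': optimizer.state_dict()
    }, model_sess_path)
    print('Done saving')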
def pretrain_discriminator(model_sess_path,          # load
                           discriminator_sess_path,  # load + save
                           batch_data_generator,     # Dataset(...).batches(...)
                           discriminator_config_overwrite=None,
                           gradient_clipping=False,
                           control_ratio=1.0,
                           num_iter=-1,
                           save_interval=60.0,
                           discriminator_lr=0.001,
                           enable_logging=False,
                           auto_sample_factor=False,
                           sample_factor=1.0):
    # Avoid a mutable default argument
    discriminator_config_overwrite = discriminator_config_overwrite or {}

    print('-' * 70)
    print('model_sess_path:', model_sess_path)
    print('discriminator_sess_path:', discriminator_sess_path)
    print('discriminator_config_overwrite:', discriminator_config_overwrite)
    print('sample_factor:', sample_factor)
    print('auto_sample_factor:', auto_sample_factor)
    print('discriminator_lr:', discriminator_lr)
    print('gradient_clipping:', gradient_clipping)
    print('control_ratio:', control_ratio)
    print('num_iter:', num_iter)
    print('save_interval:', save_interval)
    print('enable_logging:', enable_logging)
    print('-' * 70)

    # Load generator
    model_sess = torch.load(model_sess_path)
    model_config = model_sess['model_config']
    model = PerformanceRNN(**model_config).to(device)
    model.load_state_dict(model_sess['model_state'])
    print(f'Generator from "{model_sess_path}"')
    print(model)
    print('-' * 70)

    # Load discriminator and optimizer
    global discriminator_config  # module-level default config dict
    try:
        discriminator_sess = torch.load(discriminator_sess_path)
        discriminator_config = discriminator_sess['discriminator_config']
        discriminator_state = discriminator_sess['discriminator_state']
        discriminator_optimizer_state = discriminator_sess[
            'discriminator_optimizer_state']
        print(f'Discriminator from "{discriminator_sess_path}"')
        discriminator_loaded = True
    except (FileNotFoundError, KeyError):
        # No usable saved session: start a new one from the default config
        print(f'New discriminator session at "{discriminator_sess_path}"')
        discriminator_config.update(discriminator_config_overwrite)
        discriminator_loaded = False

    discriminator = EventSequenceEncoder(**discriminator_config).to(device)
    optimizer = optim.Adam(discriminator.parameters(), lr=discriminator_lr)
    if discriminator_loaded:
        discriminator.load_state_dict(discriminator_state)
        optimizer.load_state_dict(discriminator_optimizer_state)

    print(discriminator)
    print(optimizer)
    print('-' * 70)

    def save_discriminator():
        print(f'Saving to "{discriminator_sess_path}"')
        torch.save({
            'discriminator_config': discriminator_config,
            'discriminator_state': discriminator.state_dict(),
            'discriminator_optimizer_state': optimizer.state_dict()
        }, discriminator_sess_path)
        print('Done saving')

    # Disable gradients for the generator; only the discriminator is trained here
    for parameter in model.parameters():
        parameter.requires_grad_(False)

    model.eval()
    discriminator.train()

    loss_func = nn.BCEWithLogitsLoss()
    last_save_time = time.time()

    if enable_logging:
        from tensorboardX import SummaryWriter
        writer = SummaryWriter()

    try:
        for i, (events, controls) in enumerate(batch_data_generator):
            if i == num_iter:
                break

            steps, batch_size = events.shape

            # Prepare inputs
            events = torch.LongTensor(events).to(device)
            if np.random.random() <= control_ratio:
                controls = torch.FloatTensor(controls).to(device)
            else:
                controls = None
            init = torch.randn(batch_size, model.init_dim).to(device)

            # Predict for real event sequence
            real_events = events
            real_logit = discriminator(real_events, output_logits=True)
            real_target = torch.ones_like(real_logit).to(device)

            if auto_sample_factor:
                sample_factor = np.random.choice([
                    0.1, 0.4, 0.6, 0.7, 0.8, 0.9, 1.0,
                    1.1, 1.2, 1.4, 1.6, 2.0, 4.0, 10.0])

            # Predict for fake event sequence from the generator
            fake_events = model.generate(init, steps, None, controls,
                                         greedy=0, output_type='index',
                                         temperature=sample_factor)
            fake_logit = discriminator(fake_events, output_logits=True)
            fake_target = torch.zeros_like(fake_logit).to(device)

            # Compute loss
            loss = (loss_func(real_logit, real_target) +
                    loss_func(fake_logit, fake_target)) / 2

            # Backprop
            discriminator.zero_grad()
            loss.backward()

            # Gradient clipping
            norm = utils.compute_gradient_norm(discriminator.parameters())
            if gradient_clipping:
                nn.utils.clip_grad_norm_(discriminator.parameters(),
                                         gradient_clipping)

            optimizer.step()

            # Logging
            loss = loss.item()
            norm = norm.item()
            print(f'{i} loss: {loss}, norm: {norm}, sf: {sample_factor}')

            if enable_logging:
                writer.add_scalar('pretrain/D/loss/all', loss, i)
                writer.add_scalar(f'pretrain/D/loss/{sample_factor}', loss, i)
                writer.add_scalar(f'pretrain/D/norm/{sample_factor}', norm, i)

            if last_save_time + save_interval < time.time():
                last_save_time = time.time()
                save_discriminator()

    except KeyboardInterrupt:
        save_discriminator()
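
# Example invocation (a sketch): the session paths and the Dataset batch
# arguments below are placeholders; only the fact that batch_data_generator
# comes from Dataset(...).batches(...) is taken from the signature above.
if __name__ == '__main__':
    from data import Dataset  # assumed location of the Dataset class

    batches = Dataset('dataset/processed').batches(
        batch_size=64, window_size=200, stride_size=10)

    pretrain_discriminator(
        model_sess_path='save/model.sess',              # pretrained generator
        discriminator_sess_path='save/discriminator.sess',
        batch_data_generator=batches,
        gradient_clipping=1.0,
        num_iter=10000,
        enable_logging=True,
        auto_sample_factor=True)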