import numpy as np
import torch
from torch import optim

# Project-local modules/globals assumed by this excerpt:
# dataset, util, message_flags, discrete_util, pretrain,
# FLAGS, device, Vocabulary, reconstruction_accuracy.


def train():
    # Build a fresh encoder/decoder pair sized to the flattened message.
    encoder = dataset.Encoder(
        message_flags.flattened_message_size()).to(device)
    decoder = dataset.Decoder(
        message_flags.flattened_message_size()).to(device)
    all_params = list(encoder.parameters()) + list(decoder.parameters())
    optimizer = optim.Adam(all_params, weight_decay=1e-5)
    losses = []
    encoder.train()
    decoder.train()
    for iter_idx, batch in enumerate(
            util.batch_iterator(dataset.get_states_and_actions(), 20)):
        iteration = iter_idx + 1
        states = [s for s, t in batch]
        targets = [t for s, t in batch]
        optimizer.zero_grad()
        state_variable = dataset.state_to_variable_batch(states).to(device)
        target_variable = dataset.output_to_variable_batch(
            targets, states).to(device)
        encoder_output = encoder(state_variable, target_variable)
        # Optionally discretize the message before it reaches the decoder.
        decoder_input = (encoder_output if FLAGS.continuous_message else
                         discrete_util.discrete_transformation(encoder_output))
        prediction = decoder(state_variable, decoder_input, target_variable)
        prediction_loss = dataset.loss(prediction, target_variable)
        if FLAGS.continuous_message:
            loss = prediction_loss
        else:
            # Discrete messages get a KL penalty pulling them toward the prior.
            loss = (prediction_loss + FLAGS.pretrain_kl_weight *
                    discrete_util.kl_flattened(encoder_output))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if iteration % 1000 == 0:
            # Periodically report progress, checkpoint, and evaluate.
            print('===== iteration %s, average loss %s =====' %
                  (iteration, np.mean(losses)))
            losses = []
            torch.save(encoder.state_dict(),
                       FLAGS.pretrain_prefix + 'encoder_parameters.pt')
            torch.save(decoder.state_dict(),
                       FLAGS.pretrain_prefix + 'decoder_parameters.pt')
            print('reconstruction accuracy:',
                  reconstruction_accuracy(encoder, decoder))
        if (FLAGS.pretrain_iterations is not None
                and iteration >= FLAGS.pretrain_iterations):
            break
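# train() relies on util.batch_iterator, which is not shown in this excerpt.
# A minimal sketch consistent with its use above (fixed-size batches of
# (state, target) pairs, cycling until the loop breaks on
# FLAGS.pretrain_iterations) might look like this; it is an assumption,
# not the verbatim original:
def batch_iterator(examples, batch_size):
    # Cycle over the data indefinitely, yielding slices of at most
    # batch_size examples; the caller decides when to stop.
    while True:
        for start in range(0, len(examples), batch_size):
            yield examples[start:start + batch_size]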
def __init__(self):
    self.vocab = Vocabulary()
    # Maps natural-language token sequences into the flattened message space.
    self.language_module = dataset.LSTMLanguageModule(
        message_flags.flattened_message_size(),
        self.vocab.get_vocab_size()).to(device)
    self.decoder = dataset.Decoder(
        message_flags.flattened_message_size()).to(device)
    all_params = (list(self.language_module.parameters()) +
                  list(self.decoder.parameters()))
    self.optimizer = optim.Adam(all_params, weight_decay=1e-5)
    self.training_examples = []
def __init__(self):
    self.vocab = Vocabulary()
    self.language_module = dataset.LSTMLanguageModule(
        message_flags.flattened_message_size(),
        self.vocab.get_vocab_size()).to(device)
    self.training_examples = []
    # Reuse the pretrained encoder/decoder; both start in eval mode and the
    # encoder stays frozen throughout.
    self.encoder = pretrain.load_saved_encoder().to(device)
    self.encoder.eval()
    self.decoder = pretrain.load_saved_decoder().to(device)
    self.decoder.eval()
    params_to_train = list(self.language_module.parameters())
    if FLAGS.model_train_decoder:
        # Optionally fine-tune the decoder alongside the language module.
        params_to_train.extend(list(self.decoder.parameters()))
    self.optimizer = optim.Adam(params_to_train, weight_decay=1e-5)
def load_saved_decoder():
    # Restore the decoder weights checkpointed during pretraining.
    decoder = dataset.Decoder(message_flags.flattened_message_size())
    decoder.load_state_dict(
        torch.load(FLAGS.pretrain_prefix + 'decoder_parameters.pt'))
    return decoder
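# The constructor above also calls pretrain.load_saved_encoder(), which is
# not shown in this excerpt. A plausible counterpart to load_saved_decoder(),
# mirroring the 'encoder_parameters.pt' checkpoint written in train(), would
# be the following sketch (an assumption, not the verbatim original):
def load_saved_encoder():
    # Restore the encoder weights checkpointed during pretraining.
    encoder = dataset.Encoder(message_flags.flattened_message_size())
    encoder.load_state_dict(
        torch.load(FLAGS.pretrain_prefix + 'encoder_parameters.pt'))
    return encoder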