def parrot_initialization_rgc(dataset, emb_path, dc=None, encoder=None, dddqn=None):
    ''' Trains the rgc to repeat the input '''
    # TODO: save the optimizer state
    if dc is None:
        dc = DataContainer(dataset, emb_path)
        dc.prepare_data()
    x_batch, y_parrot_batch, sl_batch = u.to_batch(dc.x, dc.y_parrot_padded, dc.sl, batch_size=dc.batch_size)

    # initialize the rnn cell of the encoder and the dddqn
    rep = input('Load RNN cell pretrained for the encoder & dddqn? (y or n): ')
    if encoder is None:
        encoder = EncoderRNN(num_units=256)
    if rep == 'y' or rep == '':
        encoder.load(name='EncoderRNN-0')
    else:
        choose_best_rnn_pretrained(encoder, encoder.encoder_cell, dc, search_size=1, multiprocessed=False)

    # we do not need to train the dddqn rnn layer since we already trained the encoder rnn layer,
    # we just have to initialize the dddqn rnn layer weights with the ones from the encoder
    if dddqn is None:
        dddqn = DDDQN(dc.word2idx, dc.idx2word, dc.idx2emb)
    u.init_rnn_layer(dddqn.lstm)
    u.update_layer(dddqn.lstm, encoder.encoder_cell)

    # define the loss function used to pretrain the rgc
    def get_loss(encoder, dddqn, epoch, x, y, sl, sos, max_steps, verbose=True):
        preds, logits, _, _, _ = pu.full_encoder_dddqn_pass(x, sl, encoder, dddqn, sos, max_steps, training=True)
        logits = tf.nn.softmax(logits)  # normalize logits between 0 & 1 to allow training through cross-entropy
        sl = [end_idx + 1 for end_idx in sl]  # sl = [len(sequence) - 1, ...] => +1 to get the length
        loss = u.cross_entropy_cost(logits, y, sequence_lengths=sl)
        if verbose:
            acc_words, acc_sentences = u.get_acc_word_seq(logits, y, sl)
            logging.info('Epoch {} -> loss = {} | acc_words = {} | acc_sentences = {}'.format(epoch, loss, acc_words, acc_sentences))
        return loss

    rep = input('Load pretrained RGC-ENCODER-DDDQN? (y or n): ')
    if rep == 'y' or rep == '':
        encoder.load('RGC/Encoder')
        dddqn.load('RGC/DDDQN')

    rep = input('Train RGC-ENCODER-DDDQN? (y or n): ')
    if rep == 'y' or rep == '':
        optimizer = tf.train.AdamOptimizer()
        # training loop over epochs and batches
        for epoch in range(300):
            verbose = True
            for x, y, sl in zip(x_batch, y_parrot_batch, sl_batch):
                sos = dc.get_sos_batch_size(len(x))
                optimizer.minimize(lambda: get_loss(encoder, dddqn, epoch, x, y, sl, sos, dc.max_tokens, verbose=verbose))
                verbose = False
            encoder.save(name='RGC/Encoder')
            dddqn.save(name='RGC/DDDQN')
            acc = pu.get_acc_full_dataset(dc, encoder, dddqn)
            logging.info('Validation accuracy = {}'.format(acc))
            if acc > 0.95:
                logging.info('Stopping criteria on validation accuracy raised')
                break

    return encoder, dddqn, dc
def parrot_initialization_encoder_decoder(dataset, emb_path, attention):
    ''' Trains the encoder-decoder to reproduce the input '''
    dc = DataContainer(dataset, emb_path)
    dc.prepare_data()
    x_batch, y_parrot_batch, sl_batch = u.to_batch(dc.x, dc.y_parrot_padded, dc.sl, batch_size=dc.batch_size)

    def get_loss(encoder, decoder, epoch, x, y, sl, sos):
        output, cell_state = encoder.forward(x, sl)
        loss = decoder.get_loss(epoch, sos, (cell_state, output), y, sl, x, encoder.outputs)
        return loss

    if os.path.isdir('models/Encoder-Decoder'):
        rep = input('Load previously trained Encoder-Decoder? (y or n): ')
        if rep == 'y' or rep == '':
            encoder = EncoderRNN()
            decoder = DecoderRNN(dc.word2idx, dc.idx2word, dc.idx2emb, max_tokens=dc.max_tokens, attention=attention)
            encoder.load(name='Encoder-Decoder/Encoder')
            decoder.load(name='Encoder-Decoder/Decoder')
            sos = dc.get_sos_batch_size(len(dc.x))
            see_parrot_results(encoder, decoder, 'final', dc.x, dc.y_parrot_padded, dc.sl, sos, greedy=True)
        else:
            encoder, decoder = choose_coders(dc, attention, search_size=5)
    else:
        encoder, decoder = choose_coders(dc, attention, search_size=5)

    optimizer = tf.train.AdamOptimizer()
    for epoch in range(300):
        for x, y, sl in zip(x_batch, y_parrot_batch, sl_batch):
            sos = dc.get_sos_batch_size(len(x))
            # grad_n_vars = optimizer.compute_gradients(lambda: get_loss(encoder, decoder, epoch, x, y, sl, sos))
            # optimizer.apply_gradients(grad_n_vars)
            optimizer.minimize(lambda: get_loss(encoder, decoder, epoch, x, y, sl, sos))
        if epoch % 30 == 0:
            # to reduce training time, compute the global accuracy only every 30 epochs
            sos = dc.get_sos_batch_size(len(dc.x))
            see_parrot_results(encoder, decoder, epoch, dc.x, dc.y_parrot_padded, dc.sl, sos, greedy=True)
            # see_parrot_results(encoder, decoder, epoch, dc.x, dc.y_parrot_padded, dc.sl, sos)
        encoder.save(name='Encoder-Decoder/Encoder')
        decoder.save(name='Encoder-Decoder/Decoder')
        if decoder.parrot_stopping:
            break
        # x_batch, y_parrot_batch, sl_batch = u.shuffle_data(x_batch, y_parrot_batch, sl_batch)
        # strangely, shuffling the data between epochs makes the training really noisy

    return encoder, decoder, dc
def choose_coders(dc, attention, search_size=8):
    ''' Trains search_size coders and returns the best ones '''
    encoder = EncoderRNN()
    decoder = DecoderRNN(dc.word2idx, dc.idx2word, dc.idx2emb, max_tokens=dc.max_tokens, attention=attention)

    logging.info('Choosing coders...')
    logger = logging.getLogger()
    logger.disabled = True  # silence logging while the pretraining runs are launched
    results_encoder = u.multiples_launch(pretrain_rnn_layer, [encoder, encoder.encoder_cell, dc], num_process=search_size)
    results_decoder = u.multiples_launch(pretrain_rnn_layer, [decoder, decoder.decoder_cell, dc], num_process=search_size)
    logger.disabled = False

    # each result is (accuracy, run_id); keep the pretrained cells with the best accuracy
    results_encoder.sort(key=lambda x: x[0], reverse=True)
    results_decoder.sort(key=lambda x: x[0], reverse=True)
    logging.info('Accuracy of the best encoder = {}'.format(results_encoder[0][0]))
    encoder.load(name='{}-{}'.format(encoder.name, results_encoder[0][1]))
    logging.info('Accuracy of the best decoder = {}'.format(results_decoder[0][0]))
    decoder.load(name='{}-{}'.format(decoder.name, results_decoder[0][1]), only_lstm=True)
    return encoder, decoder
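
# A minimal usage sketch of the parrot pretraining entry points defined above.
# This is an assumption about how the module is meant to be driven: the dataset
# and embedding paths below are placeholders, not files shipped with the project.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Pretrain the encoder-decoder on the parrot task (reproduce the input), with attention enabled.
    encoder, decoder, dc = parrot_initialization_encoder_decoder('data/dataset.txt', 'data/embeddings.txt', attention=True)
    # Reuse the prepared DataContainer and pretrained encoder to initialize the RGC (encoder + dddqn).
    encoder, dddqn, dc = parrot_initialization_rgc('data/dataset.txt', 'data/embeddings.txt', dc=dc, encoder=encoder)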