Example #1
import os
import logging

import tensorflow as tf

# project-local helpers assumed by this snippet: utils (u), DataContainer,
# EncoderRNN, DecoderRNN, see_parrot_results, pretrain_rnn_layer

def parrot_initialization_encoder_decoder(dataset, emb_path, attention):
  '''
  Trains (or loads) an encoder-decoder that learns to reproduce its input
  '''
  dc = DataContainer(dataset, emb_path)
  dc.prepare_data()

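  # pre-batch the whole dataset once; the same batches are reused every epoch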
  x_batch, y_parrot_batch, sl_batch = u.to_batch(dc.x, dc.y_parrot_padded, dc.sl, batch_size=dc.batch_size)

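  # wraps one encoder forward pass plus the decoder loss so the optimizer can call it as a closure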
  def get_loss(encoder, decoder, epoch, x, y, sl, sos):
    output, cell_state = encoder.forward(x, sl)
    loss = decoder.get_loss(epoch, sos, (cell_state, output), y, sl, x, encoder.outputs)
    return loss

  if os.path.isdir('models/Encoder-Decoder'):
    rep = input('Load previously trained Encoder-Decoder? (y or n): ')
    if rep in ('y', ''):
      encoder = EncoderRNN()
      decoder = DecoderRNN(dc.word2idx, dc.idx2word, dc.idx2emb, max_tokens=dc.max_tokens, attention=attention)
      encoder.load(name='Encoder-Decoder/Encoder')
      decoder.load(name='Encoder-Decoder/Decoder')
      sos = dc.get_sos_batch_size(len(dc.x))
      see_parrot_results(encoder, decoder, 'final', dc.x, dc.y_parrot_padded, dc.sl, sos, greedy=True)
    else:
      encoder, decoder = choose_coders(dc, attention, search_size=5)
  else:
    encoder, decoder = choose_coders(dc, attention, search_size=5)

  optimizer = tf.train.AdamOptimizer()

  for epoch in range(300):
    for x, y, sl in zip(x_batch, y_parrot_batch, sl_batch):
      sos = dc.get_sos_batch_size(len(x))
      # grad_n_vars = optimizer.compute_gradients(lambda: get_loss(encoder, decoder, epoch, x, y, sl, sos))
      # optimizer.apply_gradients(grad_n_vars)
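      # in eager mode, minimize() accepts a callable loss; it is the one-step
      # equivalent of the commented compute/apply pair above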
      optimizer.minimize(lambda: get_loss(encoder, decoder, epoch, x, y, sl, sos))
    if epoch % 30 == 0:
      # to reduce training time, compute global accuracy only every 30 epochs
      sos = dc.get_sos_batch_size(len(dc.x))
      see_parrot_results(encoder, decoder, epoch, dc.x, dc.y_parrot_padded, dc.sl, sos, greedy=True)
      # see_parrot_results(encoder, decoder, epoch, dc.x, dc.y_parrot_padded, dc.sl, sos)
    encoder.save(name='Encoder-Decoder/Encoder')
    decoder.save(name='Encoder-Decoder/Decoder')
    if decoder.parrot_stopping:
      break
    # x_batch, y_parrot_batch, sl_batch = u.shuffle_data(x_batch, y_parrot_batch, sl_batch)
    # strangely, shuffling the data between epochs makes the training really noisy

  return encoder, decoder, dc
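
A minimal call sketch; the file paths are placeholders, not paths from the original project:

# hypothetical usage -- file paths are illustrative only
encoder, decoder, dc = parrot_initialization_encoder_decoder(
    'data/dataset.txt',     # raw dataset consumed by DataContainer
    'data/embeddings.bin',  # pretrained word-embedding file
    attention=True)         # enable the decoder's attention mechanism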
Example #2
def choose_coders(dc, attention, search_size=8):
  '''
  Pretrains search_size candidates for each coder and returns the best encoder and decoder
  '''
  encoder = EncoderRNN()
  decoder = DecoderRNN(dc.word2idx, dc.idx2word, dc.idx2emb, max_tokens=dc.max_tokens, attention=attention)

  logging.info('Choosing coders...')
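  # mute the root logger so the parallel pretraining runs below don't flood the console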
  logger = logging.getLogger()
  logger.disabled = True

  results_encoder = u.multiples_launch(pretrain_rnn_layer, [encoder, encoder.encoder_cell, dc], num_process=search_size)
  results_decoder = u.multiples_launch(pretrain_rnn_layer, [decoder, decoder.decoder_cell, dc], num_process=search_size)

  logger.disabled = False

  results_encoder.sort(key=lambda x: x[0], reverse=True)
  results_decoder.sort(key=lambda x: x[0], reverse=True)
  logging.info('Accuracy of the best encoder = {}'.format(results_encoder[0][0]))
  encoder.load(name='{}-{}'.format(encoder.name, results_encoder[0][1]))
  logging.info('Accuracy of the best decoder = {}'.format(results_decoder[0][0]))
  decoder.load(name='{}-{}'.format(decoder.name, results_decoder[0][1]), only_lstm=True)
  return encoder, decoder
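
The indexing on results_encoder[0] assumes u.multiples_launch returns one (accuracy, run_id) pair per pretraining run, with each run's weights saved under '{name}-{run_id}'. A minimal usage sketch under that assumption, with placeholder paths:

# hypothetical usage -- paths and search_size are illustrative only
dc = DataContainer('data/dataset.txt', 'data/embeddings.bin')
dc.prepare_data()
encoder, decoder = choose_coders(dc, attention=True, search_size=8)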