Example #1
0
    def build_new_encoder_decoder_pair(self, num_new_classes=1):
        """Construct a fresh encoder/decoder pair sized for extra classes.

        The new encoder output width is 2*latent_size (gaussian params)
        plus the existing discrete dims plus `num_new_classes` newly added
        discrete dims. The decoder mirrors the current decoder's config.

        Args:
            num_new_classes: number of new discrete class dims to reserve
                in the encoder's output (default 1).

        Returns:
            (encoder, decoder) tuple of freshly constructed models.
        """
        updated_latent_size = 2*self.latent_size \
                              + self.num_discrete \
                              + num_new_classes

        # BUGFIX: was `is not 'cnn'` — identity comparison against a str
        # literal depends on interning and is not a reliable equality test.
        if self.encoder_model.layer_type != 'cnn':
            # reuse the current dense layer sizes for the new pair
            layer_sizes = self.encoder_model.sizes

            encoder = DenseEncoder(self.sess,
                                   updated_latent_size,
                                   self.is_training,
                                   scope="encoder",
                                   sizes=layer_sizes,
                                   use_ln=self.encoder_model.use_ln,
                                   use_bn=self.encoder_model.use_bn)
            # preserve whether the old decoder doubled its feature count
            is_dec_doubled = self.decoder_model.double_features > 1
            decoder = DenseEncoder(self.sess,
                                   self.input_size,
                                   self.is_training,
                                   scope="decoder",
                                   sizes=layer_sizes,
                                   double_features=is_dec_doubled,
                                   use_ln=self.decoder_model.use_ln,
                                   use_bn=self.decoder_model.use_bn)
        else:
            encoder = CNNEncoder(
                self.sess,
                updated_latent_size,
                self.is_training,
                scope="encoder",
                use_ln=self.encoder_model.use_ln,
                use_bn=self.encoder_model.use_bn,
            )
            decoder = CNNDecoder(
                self.sess,
                scope="decoder",
                double_channels=self.decoder_model.double_channels,
                input_size=self.input_size,
                is_training=self.is_training,
                use_ln=self.decoder_model.use_ln,
                use_bn=self.decoder_model.use_bn)

        return encoder, decoder
Example #2
0
def build_Nd_vae(sess,
                 source,
                 input_shape,
                 latent_size,
                 batch_size,
                 epochs=100):
    """Build a VAE over `source` and train it unless a checkpoint restores.

    Depending on FLAGS.sequential this either (a) trains a VanillaVAE once
    over source[0], or (b) runs a lifelong/sequential loop that forks a
    student VAE each time the data distribution swaps, up to
    FLAGS.max_dist_swaps distributions.

    Args:
        sess: tf.Session used by all models.
        source: data provider(s); indexed by model number in sequential mode.
        input_shape: per-sample input shape (list/tuple, no batch dim).
        latent_size: requested latent size (NOTE: shadowed below by a value
            derived from FLAGS.latent_size — the parameter is unused).
        batch_size: minibatch size for training/eval.
        epochs: training epochs for the non-sequential path (default 100).

    Returns:
        The constructed (and trained or restored) vae object.
    """
    # resolve a fresh experiment directory; restoring old runs is disabled
    base_name = os.path.join(FLAGS.base_dir, "experiment")
    print 'base_name = ', base_name
    current_model = _find_latest_experiment_number(base_name)
    if current_model != -1:
        print "\nWARNING: old experiment found, but restoring is currently bugged, training new..\n"
        base_name = base_name + "_%d" % (current_model + 1)
        latest_model = (None, 0)
        # base_name = base_name + "_%d" % current_model
        # latest_model = find_latest_file("%s/models" % base_name, "vae(\d+)")
    else:
        base_name = _build_latest_base_dir(base_name)
        latest_model = (None, 0)

    print 'base name: ', base_name, '| latest model = ', latest_model

    # our placeholders are generated externally
    is_training = tf.placeholder(tf.bool)
    x = tf.placeholder(tf.float32,
                       shape=[FLAGS.batch_size] + list(input_shape),
                       name="input_placeholder")

    # build encoder and decoder models
    # note: these can be externally built
    #       as long as it works with forward()
    # NOTE(review): this rebinds the `latent_size` parameter; sequential
    # mode reserves one extra dim (the +1), presumably for the discrete
    # index — confirm against the VAE implementation.
    latent_size = 2*FLAGS.latent_size + 1 if FLAGS.sequential \
                  else 2*FLAGS.latent_size
    encoder = CNNEncoder(sess,
                         latent_size,
                         is_training,
                         use_ln=FLAGS.use_ln,
                         use_bn=FLAGS.use_bn)
    # decoder_latent_size = FLAGS.latent_size + 1 if FLAGS.sequential \
    #                       else FLAGS.latent_size
    decoder = CNNDecoder(sess,
                         input_size=input_shape,
                         is_training=is_training,
                         double_channels=True,
                         use_ln=FLAGS.use_ln,
                         use_bn=FLAGS.use_bn)
    print 'encoder = ', encoder.get_info()
    print 'decoder = ', decoder.get_info()

    # build the vae object
    VAEObj = VAE if FLAGS.sequential else VanillaVAE
    vae = VAEObj(sess,
                 x,
                 input_size=input_shape,
                 batch_size=FLAGS.batch_size,
                 latent_size=FLAGS.latent_size,
                 discrete_size=1,
                 p_x_given_z_func=distributions.Logistic,
                 encoder=encoder,
                 decoder=decoder,
                 is_training=is_training,
                 learning_rate=FLAGS.learning_rate,
                 submodel=latest_model[1],
                 img_shape=[32, 32, 3],
                 vae_tm1=None,
                 base_dir=base_name,
                 mutual_info_reg=FLAGS.mutual_info_reg)

    # latest_model[0] is always None here (restore path disabled above),
    # so this filename never exists and we fall through to training
    model_filename = "%s/models/%s" % (base_name, latest_model[0])
    is_forked = False

    if os.path.isfile(model_filename):
        vae.restore()
    else:
        sess.run([
            tf.global_variables_initializer(),
            tf.local_variables_initializer()
        ])

        # contain all the losses for runs
        mean_loss = []
        mean_elbo = []
        mean_recon = []
        mean_latent = []

        try:
            if not FLAGS.sequential:
                vae.train(source[0],
                          batch_size,
                          display_step=1,
                          training_epochs=epochs)
                # NOTE(review): this branch unpacks 6 values while the
                # sequential branch below unpacks 8 from the same helper —
                # verify evaluate_reconstr_loss_svhn's actual arity.
                mean_t, mean_recon_t, mean_latent_t, _, _, _ \
                    = evaluate_reconstr_loss_svhn(sess, vae,
                                                     batch_size)
                mean_loss += [mean_t]
                mean_latent += [mean_latent_t]
                mean_recon += [mean_recon_t]
            else:
                current_model = 0
                total_iter = 0
                # (source index, TRUE class number) pairs seen so far
                all_models = [(current_model, source[current_model].number)]

                while True:
                    # fork if we get a new model
                    prev_model = current_model

                    # test our model every 200 iterations
                    if total_iter % 200 == 0:
                        vae.test(TEST_SET, batch_size)

                    # data iterator
                    inputs, outputs, indexes, current_model \
                        = generate_train_data(source,
                                              batch_size,
                                              batch_size,
                                              current_model)

                    # Distribution shift Swapping logic
                    if prev_model != current_model:
                        # save away the current test set loss
                        mean_t, mean_elbo_t, mean_recon_t, mean_latent_t, \
                            _, _, _, _\
                            = evaluate_reconstr_loss_svhn(sess, vae,
                                                             batch_size)
                        mean_loss += [mean_t]
                        mean_elbo += [mean_elbo_t]
                        mean_latent += [mean_latent_t]
                        mean_recon += [mean_recon_t]

                        # for the purposes of this experiment we end
                        # if we reach max_dist_swaps
                        if len(all_models) >= FLAGS.max_dist_swaps:
                            print '\ntrained %d models, exiting\n' \
                                % FLAGS.max_dist_swaps
                            break

                        # add a new discrete index if we haven't seen this distr yet
                        # if we compress, we just check to see if the "true" number is in the set
                        if FLAGS.compress_rotations:
                            # current_model_perms = set([current_model] + [i for i in range(len(source))])
                            all_true_models = [i[1] for i in all_models]
                            num_new_class = 1 if source[
                                current_model].number not in all_true_models else 0
                            print 'detected %s, prev = %s, num_new_class = %d' % (
                                str(source[current_model].number),
                                str(all_true_models), num_new_class)
                        else:
                            # just dont add dupes based on the number
                            all_models_index = [i[0] for i in all_models]
                            num_new_class = 1 if current_model not in all_models_index else 0

                        # spawn a student model for the new distribution
                        vae = vae.fork(num_new_class)
                        is_forked = True  # holds the first fork has been done [spawn student]

                        # keep track of all models (and the TRUE model)
                        # this is separated because the true model
                        # might not be the same (eg: rotations)
                        all_models.append(
                            (current_model, source[current_model].number))

                    # iterate full minibatches of the freshly generated data
                    for start, end in zip(
                            range(0,
                                  len(inputs) + 1, batch_size),
                            range(batch_size,
                                  len(inputs) + 1, batch_size)):
                        # NOTE(review): rebinds `x`, shadowing the tf
                        # placeholder created above — intentional? verify.
                        x = inputs[start:end]
                        loss, elbo, rloss, lloss = vae.partial_fit(
                            x, is_forked=is_forked)
                        print 'loss[total_iter=%d][iter=%d][model=%d] = %f, elbo loss = %f, latent loss = %f, reconstr loss = %f' \
                            % (total_iter, vae.iteration, current_model, loss, elbo, lloss,
                               rloss if rloss is not None else 0.0)

                    total_iter += 1

        except KeyboardInterrupt:
            # allow manual stop; fall through to save whatever we have
            print "caught keyboard exception..."

        vae.save()
        if FLAGS.sequential:
            # NOTE(review): `all_models` is only bound inside the sequential
            # branch of the try; an interrupt before that point would raise
            # NameError here.
            np.savetxt("%s/models/class_list.csv" % vae.base_dir,
                       all_models,
                       delimiter=",")
            print 'All seen models: ', all_models

        write_all_losses(vae.base_dir,
                         mean_loss,
                         mean_elbo,
                         mean_recon,
                         mean_latent,
                         prefix="")

    return vae
Example #3
0
def main():
    """Train a CNN encoder-decoder NMT model and checkpoint on best BLEU.

    Command line: data_folder src_lang ref_lang src_vocab.pkl ref_vocab.pkl
    Loads vocabularies, builds train/dev dataloaders, optionally resumes
    from a saved checkpoint, then trains `conf.epochs` epochs, evaluating
    average BLEU on the dev set after each epoch and saving the models
    whenever it improves.
    """
    data_folder = sys.argv[1]
    src_lang = sys.argv[2]
    ref_lang = sys.argv[3]
    conf = CNNConfig(data_folder, src_lang, ref_lang)

    # BUGFIX: pickle.load(open(...)) leaked file handles; close them
    # deterministically with context managers.
    with open(sys.argv[4], 'rb') as f:
        src_vocab = pickle.load(f)
    with open(sys.argv[5], 'rb') as f:
        ref_vocab = pickle.load(f)

    train_dataset = NMTDataset(load_data(conf.train_src_path),
                               load_data(conf.train_ref_path), src_vocab,
                               ref_vocab)
    train_dataloader = NMTDataLoader(train_dataset,
                                     batch_size=conf.batch_size,
                                     num_workers=0,
                                     shuffle=True)
    print('%d training dataset loaded.' % len(train_dataset))

    dev_dataset = NMTDataset(load_data(conf.dev_src_path),
                             load_data(conf.dev_ref_path), src_vocab,
                             ref_vocab)
    dev_dataloader = NMTDataLoader(dev_dataset,
                                   batch_size=conf.batch_size,
                                   num_workers=0)
    print('%d validation dataset loaded.' % len(dev_dataset))

    save_name = conf.save_path + '/%scnn' % ('w' if conf.word_level else '')
    start = 0
    # resume from checkpoint `start - 1` when present (with start=0 this
    # probes "..._-1", which normally does not exist, so training is fresh)
    if os.path.exists(save_name + '_encoder_' + str(start - 1)):
        encoder = torch.load(save_name + '_encoder_' + str(start - 1))
        decoder = torch.load(save_name + '_decoder_' + str(start - 1))
    else:
        encoder = CNNEncoder(conf.encoder_emb_size, conf.vocab_sizes,
                             train_dataset.src_vocab, conf.encoder_kernels,
                             len(conf.decoder_kernels), conf.encoder_dropout,
                             conf.word_level)
        decoder = CNNDecoder(conf.decoder_emb_size,
                             len(train_dataset.ref_vocab),
                             conf.decoder_kernels, conf.decoder_dropout)

    if conf.cuda:
        encoder.cuda()
        decoder.cuda()

    best_bleu = -1
    for epoch in range(start, start + conf.epochs):
        print('Epoch [{:3d}]'.format(epoch))
        train_loss = train(encoder, decoder, train_dataloader, conf)
        print('Training loss:\t%f' % train_loss)

        # average per-sentence BLEU over the whole dev set
        bleus = 0
        for src, ref, cand, bleu in evaluate_cnn(encoder, decoder,
                                                 dev_dataloader, conf.beam):
            bleus += sum(bleu)
        bleus /= len(dev_dataloader.dataset)

        print('Avg BLEU score:{:8.4f}'.format(bleus))

        if bleus > best_bleu:
            # BUGFIX: the update was commented out, so best_bleu stayed -1
            # and every epoch was saved as "best"; track it properly.
            best_bleu = bleus
            # checkpoint on CPU for portability, then move back to GPU
            torch.save(encoder.cpu(), save_name + '_encoder_' + str(epoch))
            torch.save(decoder.cpu(), save_name + '_decoder_' + str(epoch))
            if conf.cuda:
                encoder.cuda()
                decoder.cuda()