Code Example #1
 def inverse(self, out_bij):
     """ i-RevNet inverse: undo the blocks, then the initial downsampling """
     out = split(out_bij)
     # apply the block inverses in the opposite order of the forward pass
     for i in range(len(self.stack)):
         out = self.stack[-1 - i].inverse(out)
     out = merge(out[0], out[1])
     # invert the initial invertible downsampling (psi)
     x = self.init_psi.inverse(out)
     return x
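Examples #1-#3 and #5 rely on split and merge helpers that are not shown here. A minimal sketch of what they are assumed to do, namely split a tensor channel-wise into two halves and concatenate the halves back (PyTorch, NCHW layout assumed; the repository's own helpers may differ in detail):

import torch

def split(x):
    # split the channel dimension into two equal halves (even channel count assumed)
    n = x.size(1) // 2
    return x[:, :n], x[:, n:]

def merge(x1, x2):
    # concatenate the two halves back along the channel dimension
    return torch.cat((x1, x2), dim=1)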
Code Example #2
 def forward(self, x):
     """ bijective or injective block forward """
     if self.pad != 0 and self.stride == 1:
         # injective padding: merge the halves, pad channels, split again
         x = merge(x[0], x[1])
         x = self.inj_pad.forward(x)
         x1, x2 = split(x)
         x = (x1, x2)
     x1 = x[0]
     x2 = x[1]
     Fx2 = self.bottleneck_block(x2)
     if self.stride == 2:
         # invertible downsampling so the shapes match the strided bottleneck
         x1 = self.psi.forward(x1)
         x2 = self.psi.forward(x2)
     # additive coupling: y1 = F(x2) + x1, while x2 passes through unchanged
     y1 = Fx2 + x1
     return (x2, y1)
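The forward pass above also uses self.psi (an invertible spatial downsampling) and self.inj_pad (an injective channel padding), neither of which is shown. A minimal sketch of an invertible downsampling with the same forward/inverse interface, assuming a pixel-unshuffle style squeeze (the class name Psi and the exact channel ordering are assumptions; the original psi may permute channels differently):

import torch.nn as nn
import torch.nn.functional as F

class Psi(nn.Module):
    """Invertible downsampling: trade each 2x2 spatial block for 4x more channels."""
    def __init__(self, block_size=2):
        super().__init__()
        self.block_size = block_size

    def forward(self, x):
        # (N, C, H, W) -> (N, C*b*b, H/b, W/b); a pure reshape, so nothing is lost
        return F.pixel_unshuffle(x, self.block_size)

    def inverse(self, x):
        # exact inverse of forward
        return F.pixel_shuffle(x, self.block_size)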
Code Example #3
 def inverse(self, x):
     """ bijective or injective block inverse """
     x2, y1 = x[0], x[1]
     if self.stride == 2:
         x2 = self.psi.inverse(x2)
     # undo the additive coupling: x1 = y1 - F(x2)
     Fx2 = -self.bottleneck_block(x2)
     x1 = Fx2 + y1
     if self.stride == 2:
         x1 = self.psi.inverse(x1)
     if self.pad != 0 and self.stride == 1:
         # undo the injective padding
         x = merge(x1, x2)
         x = self.inj_pad.inverse(x)
         x1, x2 = split(x)
         x = (x1, x2)
     else:
         x = (x1, x2)
     return x
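The forward/inverse pair above is an additive coupling: y1 = F(x2) + x1 with x2 passed through, so the inverse only needs x1 = y1 - F(x2) and F itself never has to be invertible. A small self-contained check of that algebra (the toy F_block and the shapes are illustrative, not the repository's bottleneck_block):

import torch
import torch.nn as nn

F_block = nn.Sequential(nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
                        nn.Conv2d(8, 8, 3, padding=1))
x1, x2 = torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)

y = (x2, F_block(x2) + x1)                     # forward: (x1, x2) -> (x2, y1)
x1_rec = y[1] - F_block(y[0])                  # inverse: x1 = y1 - F(x2)
print(torch.allclose(x1_rec, x1, atol=1e-5))   # True up to float rounding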
Code Example #4
def main():
    # Parse
    parser = model_utils.get_parser()
    FLAGS, unparsed = parser.parse_known_args()
    # Setup model_dir
    if FLAGS.model_name is None:
        model_name = "LanguageModel"
    else:
        model_name = FLAGS.model_name

    model_dir = os.path.abspath(FLAGS.base_dir) + '/{}/'.format(model_name)
    if not os.path.exists(model_dir):
        model_utils.setup_model_dir(model_dir, create_base=True)
    if FLAGS.no_restore:
        model_utils.remove_history(model_dir)
        model_utils.setup_model_dir(model_dir, create_base=False)
    # Start logging
    logger = model_utils.get_logger(model_name, model_dir)
    logger.info("Started constructing {}".format(model_name))
    logger.info("Parsed args {}".format(FLAGS))
    if FLAGS.no_restore:
        logger.info('Not restoring, deleted history.')

    # Get Dataset
    logger.info("Getting dataset {}".format(FLAGS.dataset_name))
    full_dataset, tokenizer, size = data.make_dataset(FLAGS.dataset_name,
                                                      FLAGS.dataset_type,
                                                      FLAGS.data_dir,
                                                      FLAGS.seq_length)
    # Create model
    hparams = create_hparams(FLAGS.hparams)
    lm = LanguageModel(tokenizer.vocab_size, hparams.embedding_dim,
                       hparams.rnn_size, hparams.use_cudnn)
    optimizer = tf.train.AdamOptimizer(hparams.lr, hparams.beta1,
                                       hparams.beta2, hparams.epsilon)
    epoch_count = tf.Variable(1, name='epoch_count')
    global_step = tf.train.get_or_create_global_step()
    logger.info("Model created")
    # Create checkpointing
    checkpoint_dir = os.path.abspath(model_dir + 'ckpts/' + FLAGS.run_name)
    logger.info("Checkpoints at {}".format(checkpoint_dir))
    checkpoint_prefix = checkpoint_dir + '/ckpt'
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     lm=lm,
                                     epoch_count=epoch_count,
                                     global_step=global_step)
    if not FLAGS.no_restore:
        if FLAGS.load_checkpoint is not None:
            load_checkpoint = FLAGS.load_checkpoint
        else:
            load_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
            logger.info("Loading latest checkpoint...")
        logger.info("Loading checkpoint {}".format(load_checkpoint))
        checkpoint.restore(load_checkpoint)

    # Create summary writer
    summary_dir = model_dir + 'log/' + FLAGS.run_name + '/'
    summary_writer = tf.contrib.summary.create_file_writer(summary_dir,
                                                           flush_millis=1000)

    # Training
    if FLAGS.mode == "train":
        logger.info("Beginning training...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        # Get training Dataset
        logger.info("Full dataset size: {}".format(int(size)))
        logger.info("Train dataset size: {}".format(
            int(size * FLAGS.use_frac * FLAGS.train_frac)))
        train_dataset, valid_dataset = model_utils.split(
            full_dataset, size, FLAGS.use_frac, FLAGS.train_frac)
        train_dataset = train_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        valid_dataset = valid_dataset.batch(FLAGS.batch_size,
                                            drop_remainder=True)
        train_dataset = (
            tf.data.experimental.prefetch_to_device(device)(train_dataset))
        valid_dataset = (
            tf.data.experimental.prefetch_to_device(device)(valid_dataset))
        # Train loop
        train_losses = []
        val_losses = []
        patience_count = 0
        for epoch in range(FLAGS.epochs):
            cur_epoch = epoch_count.numpy() + epoch
            logger.info("Starting epoch {}...".format(cur_epoch))
            start = time.time()
            with summary_writer.as_default():
                train_loss = lm.train(train_dataset, optimizer, global_step,
                                      FLAGS.log_interval)
                logger.info("Epoch {} complete: train loss = {:0.03f}".format(
                    cur_epoch, train_loss))
                logger.info("Validating...")
                val_loss = lm.evaluate(valid_dataset)
                logger.info("Validation loss = {:0.03f}".format(val_loss))
            time_elapsed = time.time() - start
            logger.info("Took {:0.01f} seconds".format(time_elapsed))
            # Checkpoint (with early stopping, only when validation improves)
            if FLAGS.early_stopping:
                if not val_losses or val_loss < min(
                        val_losses) - FLAGS.es_delta:
                    logger.info("Checkpointing...")
                    checkpoint.save(checkpoint_prefix)
                    patience_count = 0
                elif patience_count + 1 > FLAGS.patience:
                    logger.info("Early stopping reached")
                    break
                else:
                    patience_count += 1
            else:
                logger.info("Checkpointing...")
                checkpoint.save(checkpoint_prefix)
            # Record losses so the next epoch's early-stopping check can compare
            train_losses.append(train_loss)
            val_losses.append(val_loss)

    elif FLAGS.mode == "eval":
        logger.info("Beginning evaluation...")
        device = '/gpu:0' if not FLAGS.no_gpu else '/cpu:0'
        with summary_writer.as_default():
            val_loss = lm.evaluate(full_dataset)
            logger.info("Validation loss: {:0.02f}".format(val_loss))

    elif FLAGS.mode == "generate":
        # Generate samples
        logger.info("Generating samples...")
        for _ in range(FLAGS.num_samples):
            tokens = tokenizer.tokenize(FLAGS.seed_text)
            inp = tf.constant(np.array(tokens, dtype=np.int16))
            inp = tf.expand_dims(inp, 0)
            _, state = lm.call_with_state(inp[:, 0:-1])  # Setup state
            cur_token = tokens[-1]
            done = False
            while not done:
                inp = tf.constant(np.array([cur_token], dtype=np.int16))
                inp = tf.expand_dims(inp, 0)
                logits, state = lm.call_with_state(inp, state)
                logits = tf.squeeze(logits, 0)
                logits = logits / FLAGS.temperature
                cur_token = tf.multinomial(logits, num_samples=1)[-1,
                                                                  0].numpy()
                tokens.append(cur_token)
                if len(tokens) > FLAGS.sample_length:
                    done = True
            logger.info("{}".format(tokenizer.untokenize(tokens)))
            lm.recurrent.reset_states()
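In the generation branch above, the logits are divided by FLAGS.temperature before sampling with tf.multinomial. A standalone NumPy sketch of that temperature-sampling step (the function name is illustrative, not part of the script):

import numpy as np

def sample_with_temperature(logits, temperature):
    # lower temperature sharpens the distribution, higher flattens it toward uniform
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())   # subtract the max for numerical stability
    probs /= probs.sum()
    return np.random.choice(len(probs), p=probs)

# e.g. sample_with_temperature([2.0, 1.0, 0.1], temperature=0.7)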
Code Example #5
 def inverse(self, x):
     """ undo the whole stack of invertible blocks in reverse order """
     out = split(x)
     for i in range(len(self.stack)):
         out = self.stack[-1 - i].inverse(out)
     out = merge(out[0], out[1])
     return out
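Examples #1 and #5 both invert a stack of blocks by iterating over it in reverse. A self-contained toy version of that pattern, using a minimal additive-coupling block (ToyCoupling and all shapes are assumptions for illustration):

import torch
import torch.nn as nn

class ToyCoupling(nn.Module):
    """Minimal additive coupling with the same (x1, x2) -> (x2, y1) interface as above."""
    def __init__(self, ch):
        super().__init__()
        self.f = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        x1, x2 = x
        return (x2, self.f(x2) + x1)

    def inverse(self, x):
        x2, y1 = x
        return (y1 - self.f(x2), x2)

stack = [ToyCoupling(4), ToyCoupling(4)]
x = (torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8))

out = x
for block in stack:                      # forward: blocks in order
    out = block(out)
for i in range(len(stack)):              # inverse: block inverses in reverse order,
    out = stack[-1 - i].inverse(out)     # exactly like the loops in Examples #1 and #5
print(torch.allclose(out[0], x[0], atol=1e-5) and
      torch.allclose(out[1], x[1], atol=1e-5))  # True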