def test_once(vertex_idx, uv, basis, output, saver, sess):
    """Run one cluster's model over its nearby test views and accumulate
    the weighted outputs into args.IMGS."""
    fetches = {"output": output}

    test_indices = args.INVERSE_NEARBY_INDICES[vertex_idx]

    restore_model(sess, saver,
                  os.path.join(args.checkpoint, "log_cluster_%d" % vertex_idx),
                  args)

    for test_idx in tqdm(test_indices):
        if args.real_indices is not None:
            _uv, _basis = load_data_from_file(
                args.test_indices_mapper[test_idx])
        else:
            _uv, _basis = load_data_from_file(test_idx)

        img = sess.run(fetches, feed_dict={uv: _uv, basis: _basis})["output"]

    img = img[0]  # drop the batch dimension

    # Look up the blend weight this cluster contributes to this view.
    weight = None
    for i in range(3):
        if vertex_idx == args.NEARBY_INDICES[test_idx][i]:
            weight = args.WEIGHTS[test_idx][i]
            break
    assert weight is not None, (
        'cluster %d is not among the nearby clusters of view %d'
        % (vertex_idx, test_idx))

        if test_idx not in args.IMGS:
            args.IMGS[test_idx] = img * weight
        else:
            args.IMGS[test_idx] += img * weight
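
# Taken together with the weighted accumulation above, a driver loop would
# call test_once once per cluster and then write out the blended results.
# A minimal sketch: args.num_clusters, args.output_dir, and save_image are
# assumptions for illustration and are not defined in this example.
def test_all(uv, basis, output, saver, sess):
    # Assuming each view's nearby-cluster weights sum to one, args.IMGS
    # holds fully blended images once every cluster has contributed.
    for vertex_idx in range(args.num_clusters):
        test_once(vertex_idx, uv, basis, output, saver, sess)
    for test_idx, img in args.IMGS.items():
        save_image(os.path.join(args.output_dir, "%06d.png" % test_idx), img)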
Example #2
def train(gitapp: controller.GetInputTargetAndPredictedParameters):
    """Train a model."""
    g = tf.Graph()
    with g.as_default():
        total_loss_op, _, _ = total_loss(gitapp)

        if FLAGS.optimizer == OPTIMIZER_MOMENTUM:
            # TODO(ericmc): We may want to do weight decay with the other
            # optimizers, too.
            learning_rate = tf.train.exponential_decay(
                FLAGS.learning_rate,
                slim.variables.get_global_step(),
                FLAGS.learning_decay_steps,
                0.999,
                staircase=False)
            tf.summary.scalar('learning_rate', learning_rate)

            optimizer = tf.train.MomentumOptimizer(learning_rate, 0.875)
        elif FLAGS.optimizer == OPTIMIZER_ADAGRAD:
            optimizer = tf.train.AdagradOptimizer(FLAGS.learning_rate)
        elif FLAGS.optimizer == OPTIMIZER_ADAM:
            optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
        else:
            raise NotImplementedError('Unsupported optimizer: %s' %
                                      FLAGS.optimizer)

        # Set up training.
        train_op = slim.learning.create_train_op(total_loss_op,
                                                 optimizer,
                                                 summarize_gradients=True)

        if FLAGS.restore_directory:
            init_fn = util.restore_model(FLAGS.restore_directory,
                                         FLAGS.restore_logits)

        else:
            logging.info('Training a new model.')
            init_fn = None

        total_variable_size, _ = slim.model_analyzer.analyze_vars(
            slim.get_variables(), print_info=True)
        logging.info('Total number of variables: %d', total_variable_size)

        log_entry_points(g)

        slim.learning.train(
            train_op=train_op,
            logdir=output_directory(),
            master=FLAGS.master,
            is_chief=FLAGS.task == 0,
            number_of_steps=None,
            save_summaries_secs=FLAGS.save_summaries_secs,
            save_interval_secs=FLAGS.save_interval_secs,
            init_fn=init_fn,
            saver=tf.train.Saver(keep_checkpoint_every_n_hours=2.0))
Example #3
    saver = tf.train.Saver(
        # Exclude optimizer slot variables and the train op from checkpoints.
        var_list=[
            var for var in tf.global_variables()
            if 'Adam' not in var.name and 'train_op' not in var.name
        ],
        save_relative_paths=True)

    if args.summaryDir is None:
        args.summaryDir = args.logDir
    train_writer = tf.summary.FileWriter(args.summaryDir, sess.graph)

    # Initialize all variables.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    if args.checkpoint is not None:
        restore_model(sess, saver, args.checkpoint)

    logger.info('---Start training...---')

    for step in tqdm(range(0, args.max_steps), file=sys.stdout):

        def should(freq):
            return freq > 0 and (step % freq == 0
                                 or step == args.max_steps - 1)

        fetches = {'loss': model.loss, 'output': model.output}

        if should(args.display_freq):
            fetches["summary"] = model.summary_op

        fetches["train_op"] = model.train_op
Example #4
def run_training_reinforce(hps, retrain_model_hps, eval_hps, model_train, model_decode, batcher, neg_batcher, vocab):
    train_dir = os.path.join(FLAGS.log_root, 'train')
    
    if not os.path.exists(train_dir):
        os.makedirs(train_dir)

    config = util.get_config()
    # tf.device() returns a context manager, so enter it directly rather
    # than calling it.
    default_device = tf.device('/cpu:0')

    with default_device:
        G_d = tf.Graph()
        G_t = tf.Graph()
        R = tf.Graph()

        with G_d.as_default():
            model_decode.build_graph()
            model_dvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        with G_t.as_default():
            model_train.build_graph()
            model_tvars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        ranker, rank_sess, _ = MultiFeedForwardClassifier.load(
            FLAGS.rank_log_root, graph=R, batch_size=FLAGS.batch_size)

        rank_word_dict, rank_embeddings = ioutils.load_embeddings(
            FLAGS.rank_embed_size, FLAGS.rank_vocab_file,
            generate=False, load_extra_from=FLAGS.rank_log_root, normalize=True)

        ranker.initialize_embeddings(rank_sess, rank_embeddings)

    decode_saver = tf.train.Saver(model_dvars)
    train_saver = tf.train.Saver(model_tvars, max_to_keep=10)
    model_ckpt_state = tf.train.get_checkpoint_state(train_dir)

    decode_sess = tf.Session(config=config, graph=G_d)
    train_sess = tf.Session(config=config, graph=G_t)
    train_saver = util.restore_model(
        train_sess, train_saver, model_ckpt_state.model_checkpoint_path, model_tvars, save=True)

    decode_saver = util.restore_model(
        decode_sess, decode_saver, model_ckpt_state.model_checkpoint_path, model_dvars)

    decoder = BeamSearchDecoder(model_decode, None, vocab, sess=decode_sess)

    train_step = 0

    hist_batch = []

    while True:
        batch = batcher.next_batch()
        neg_batch = neg_batcher.next_batch()
        pos_tg_len = np.sum(batch.padding_mask, axis=1)
        neg_tg_len = np.sum(neg_batch.padding_mask, axis=1)
        model_ckpt_state = tf.train.get_checkpoint_state(train_dir)
        decode_saver.restore(decode_sess, model_ckpt_state.model_checkpoint_path)
        decoder = BeamSearchDecoder(model_decode, None, vocab, sess=decode_sess)
        pos_gen_out, pos_gen_out_extended_vocab, pos_all_hyp = decoder.sample_batch_wise(
            batch, sampling=FLAGS.sampling, preserve_hyp=True, temp_ratio=FLAGS.temp_ratio)
        neg_gen_out, neg_gen_out_extended_vocab, neg_all_hyp = decoder.sample_batch_wise(
            neg_batch, sampling=FLAGS.sampling, preserve_hyp=True, temp_ratio=FLAGS.temp_ratio)
        
        assert len(pos_gen_out) == hps.batch_size

        pos_decode_len = util.measure_len(pos_gen_out, vocab)
        neg_decode_len = util.measure_len(neg_gen_out, vocab)

        rank_src = util.convert_to_full_vocab(
            batch.enc_batch_extend_vocab, rank_word_dict, vocab, batch.art_oovs)
        rank_tg = util.convert_to_full_vocab(
            pos_gen_out_extended_vocab, rank_word_dict, vocab, batch.art_oovs)
        pos_sent_reward = ranker.eval(rank_sess, rank_src, rank_tg, batch.enc_lens, np.array(pos_decode_len)-1)

        rank_src = util.convert_to_full_vocab(
            neg_batch.enc_batch_extend_vocab, rank_word_dict, vocab, neg_batch.art_oovs)
        rank_tg = util.convert_to_full_vocab(
            neg_gen_out_extended_vocab, rank_word_dict, vocab, neg_batch.art_oovs)
        neg_sent_reward = ranker.eval(rank_sess, rank_src, rank_tg, neg_batch.enc_lens, np.array(neg_decode_len)-1)

        sent_reward = np.concatenate((pos_sent_reward, neg_sent_reward), axis=0)
        ###################################################################################
        # train G
        pos_decode_reward = cal_ranker_reward(
            batch, vocab, rank_word_dict, pos_all_hyp, decoder, ranker, rank_sess, pos_tg_len, rank_gt=FLAGS.rank_gt)
        neg_decode_reward = cal_ranker_reward(
            neg_batch, vocab, rank_word_dict, neg_all_hyp, decoder, ranker, rank_sess, neg_tg_len)
        decode_reward = np.concatenate((pos_decode_reward, neg_decode_reward), axis=0)

        # Mask out reward positions past each sequence's decoded length.
        msk = np.zeros_like(decode_reward)
        decode_len = np.concatenate((pos_decode_len, neg_decode_len), axis=0)

        for b, l in enumerate(decode_len):
            msk[b, :l] = 1

        rewards = decode_reward * msk / FLAGS.num_mc 

        sent_reward = util.rank_sentence(sent_reward, scale=FLAGS.sent_scale)
        rewards = util.rescale_reward(rewards, sent_reward, msk, scale=FLAGS.token_scale)

        rewards = rewards * msk

        dec_inp, target, padding = util.prepare_retrain_data(
            hps, pos_gen_out, pos_gen_out_extended_vocab, vocab, pos_decode_len)
        neg_dec_inp, neg_target, neg_padding = util.prepare_retrain_data(
            hps, neg_gen_out, neg_gen_out_extended_vocab, vocab, neg_decode_len)

        dec_inp = np.concatenate((dec_inp, dec_inp, neg_dec_inp), axis=0)
        target = np.concatenate((target, target, neg_target), axis=0)
        padding = np.concatenate((padding, padding, neg_padding), axis=0)
        rewards = np.concatenate((
            rewards[:FLAGS.batch_size],
            # np.ones(FLAGS.batch_size) cannot concatenate with the 2-D token
            # rewards; ones_like gives the middle copy a uniform reward of
            # one with a matching shape.
            np.ones_like(rewards[:FLAGS.batch_size]),
            rewards[FLAGS.batch_size:]
            ))
        rslts = model_train.run_train_step_with_reward(
            train_sess, batch, dec_inp, target, padding, rewards, 
            neg_batch=neg_batch, temp_ratio=FLAGS.temp_ratio)

        train_step = rslts['global_step']
        train_saver.save(
            train_sess, 
            os.path.join(train_dir, 'model.ckpt'),
            global_step=train_step
            )

        if train_step % 50 == 0:
            train_saver.save(
                train_sess, os.path.join(save_dir, 'model.ckpt'),
                global_step=train_step
                )

            eval_loss = 0
            eval_batcher = SimpleBatcher(FLAGS.eval_data_path, vocab, eval_hps)
            eval_batcher_sz = SimpleBatcher(FLAGS.eval_data_path, vocab, eval_hps)

            eval_saver.restore(eval_sess, model_ckpt_state.model_checkpoint_path)
            for _ in range(eval_batcher.num_batch):
                eval_batch = eval_batcher.next_batch()
                eval_loss += model_eval.run_eval_step(eval_sess, eval_batch)['loss']

            score = 0
            sample_score = 0
            rouge_score = 0
            rouge2_score = 0
            rougel_score = 0
            bleu_score = 0

            for j in range(eval_batcher_sz.num_batch):
                eval_batch = eval_batcher_sz.next_batch()
                tg_len = np.sum(eval_batch.padding_mask, axis=1)
                _, sample_out_extended_vocab = decoder.sample_batch_wise(
                    eval_batch, sampling=FLAGS.rank_sampling, temp_ratio=FLAGS.sample_temp_ratio)
                sample_len = util.measure_len(sample_out_extended_vocab, vocab)
                gen_out, gen_out_extended_vocab = decoder.sample_batch_wise
Example #5
def infer(
    gitapp: controller.GetInputTargetAndPredictedParameters,
    restore_directory: str,
    output_directory: str,
    extract_patch_size: int,
    stitch_stride: int,
    infer_size: int,
    channel_whitelist: Optional[List[str]],
    simplify_error_panels: bool,
):
    """Runs inference on an image.

  Args:
    gitapp: GetInputTargetAndPredictedParameters.
    restore_directory: Where to restore the model from.
    output_directory: Where to write the generated images.
    extract_patch_size: The size of input to the model.
    stitch_stride: The stride size when running model inference.
      Equivalently, the output size of the model.
    infer_size: The number of simultaneous inferences to perform in the
      row and column dimensions.
      For example, if this is 8, inference will be performed in 8 x 8 blocks
      for a batch size of 64.
    channel_whitelist: If provided, only images for the given channels will
      be produced.
      This can be used to create simpler error panels.
    simplify_error_panels: Whether to create simplified error panels.

  Raises:
    ValueError: If
      1) The DataParameters don't contain a ReadPNGsParameters.
      2) The images must be larger than the input to the network.
      3) The graph must not contain queues.
  """
    rpp = gitapp.dp.io_parameters
    if not isinstance(rpp, data_provider.ReadPNGsParameters):
        raise ValueError(
            'Data provider must contain a ReadPNGsParameters, but was: %r' %
            gitapp.dp)

    original_crop_size = rpp.crop_size
    image_num_rows, image_num_columns = util.image_size(rpp.directory)
    logging.info('Uncropped image size is %d x %d', image_num_rows,
                 image_num_columns)
    image_num_rows = min(image_num_rows, original_crop_size)
    if image_num_rows < extract_patch_size:
        raise ValueError(
            'Image is too small for inference to be performed: %d vs %d' %
            (image_num_rows, extract_patch_size))
    image_num_columns = min(image_num_columns, original_crop_size)
    if image_num_columns < extract_patch_size:
        raise ValueError(
            'Image is too small for inference to be performed: %d vs %d' %
            (image_num_columns, extract_patch_size))
    logging.info('After cropping, input image size is (%d, %d)',
                 image_num_rows, image_num_columns)

    num_row_inferences = (image_num_rows -
                          extract_patch_size) // (stitch_stride * infer_size)
    num_column_inferences = (image_num_columns - extract_patch_size) // (
        stitch_stride * infer_size)
    logging.info('Running %d x %d inferences', num_row_inferences,
                 num_column_inferences)
    num_output_rows = (num_row_inferences * infer_size * stitch_stride)
    num_output_columns = (num_column_inferences * infer_size * stitch_stride)
    logging.info('Output image size is (%d, %d)', num_output_rows,
                 num_output_columns)

    g = tf.Graph()
    with g.as_default():
        row_start = tf.placeholder(dtype=np.int32, shape=[])
        column_start = tf.placeholder(dtype=np.int32, shape=[])
        # Replace the parameters with a new set, which will cause the network to
        # run inference in just a local region.
        gitapp = gitapp._replace(dp=gitapp.dp._replace(
            io_parameters=rpp._replace(
                row_start=row_start,
                column_start=column_start,
                crop_size=(infer_size - 1) * stitch_stride +
                extract_patch_size,
            )))

        visualization_lts = controller.setup_stitch(gitapp)

        def get_statistics(tensor):
            rc = lt.ReshapeCoder(list(tensor.axes.keys())[:-1], ['batch'])
            return rc.decode(ops.distribution_statistics(rc.encode(tensor)))

        visualize_input_lt = visualization_lts['input']
        visualize_predict_input_lt = get_statistics(
            visualization_lts['predict_input'])
        visualize_target_lt = visualization_lts['target']
        visualize_predict_target_lt = get_statistics(
            visualization_lts['predict_target'])

        input_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1, num_output_rows, num_output_columns,
                    len(gitapp.dp.input_z_values), 1, 2
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.input_z_values),
                ('channel', ['TRANSMISSION']),
                ('mask', [False, True]),
            ])
        predict_input_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.input_z_values),
                    1,
                    len(visualize_predict_input_lt.axes['statistic']),
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.input_z_values),
                ('channel', ['TRANSMISSION']),
                visualize_predict_input_lt.axes['statistic'],
            ])
        input_error_panel_lt = visualize.error_panel_from_statistics(
            input_lt, predict_input_lt, simplify_error_panels)

        target_lt = lt.LabeledTensor(
            tf.placeholder(dtype=np.float32,
                           shape=[
                               1, num_output_rows, num_output_columns,
                               len(gitapp.dp.target_z_values),
                               len(gitapp.dp.target_channel_values) + 1, 2
                           ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.target_z_values),
                ('channel',
                 gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
                ('mask', [False, True]),
            ])
        predict_target_lt = lt.LabeledTensor(
            tf.placeholder(
                dtype=np.float32,
                shape=[
                    1,
                    num_output_rows,
                    num_output_columns,
                    len(gitapp.dp.target_z_values),
                    len(gitapp.dp.target_channel_values) + 1,
                    len(visualize_predict_target_lt.axes['statistic']),
                ]),
            axes=[
                'batch',
                'row',
                'column',
                ('z', gitapp.dp.target_z_values),
                ('channel',
                 gitapp.dp.target_channel_values + ['NEURITE_CONFOCAL']),
                visualize_predict_target_lt.axes['statistic'],
            ])

        logging.info('input_lt: %r', input_lt)
        logging.info('predict_input_lt: %r', predict_input_lt)
        logging.info('target_lt: %r', target_lt)
        logging.info('predict_target_lt: %r', predict_target_lt)

        def select_channels(tensor):
            if channel_whitelist is not None:
                return lt.select(tensor, {'channel': channel_whitelist})
            else:
                return tensor

        target_error_panel_lt = visualize.error_panel_from_statistics(
            select_channels(target_lt), select_channels(predict_target_lt),
            simplify_error_panels)

        # There shouldn't be any queues in this configuration.
        queue_runners = g.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
        if queue_runners:
            raise ValueError('Graph must not have queues, but had: %r' %
                             queue_runners)

        logging.info('Attempting to find restore checkpoint in %s',
                     restore_directory)
        init_fn = util.restore_model(restore_directory,
                                     restore_logits=True,
                                     restore_global_step=True)

        with tf.Session() as sess:
            logging.info('Generating images')
            init_fn(sess)

            input_rows = []
            predict_input_rows = []
            target_rows = []
            predict_target_rows = []
            for infer_row in range(num_row_inferences):
                input_row = []
                predict_input_row = []
                target_row = []
                predict_target_row = []
                for infer_column in range(num_column_inferences):
                    rs = infer_row * infer_size * stitch_stride
                    cs = infer_column * infer_size * stitch_stride
                    logging.info('Running inference at offset: (%d, %d)', rs,
                                 cs)
                    [inpt, predict_input, target, predict_target] = sess.run(
                        [
                            visualize_input_lt,
                            visualize_predict_input_lt,
                            visualize_target_lt,
                            visualize_predict_target_lt,
                        ],
                        feed_dict={
                            row_start: rs,
                            column_start: cs,
                        })

                    input_row.append(inpt)
                    predict_input_row.append(predict_input)
                    target_row.append(target)
                    predict_target_row.append(predict_target)
                input_rows.append(np.concatenate(input_row, axis=2))
                predict_input_rows.append(
                    np.concatenate(predict_input_row, axis=2))
                target_rows.append(np.concatenate(target_row, axis=2))
                predict_target_rows.append(
                    np.concatenate(predict_target_row, axis=2))

            logging.info('Stitching')
            stitched_input = np.concatenate(input_rows, axis=1)
            stitched_predict_input = np.concatenate(predict_input_rows, axis=1)
            stitched_target = np.concatenate(target_rows, axis=1)
            stitched_predict_target = np.concatenate(predict_target_rows,
                                                     axis=1)

            logging.info('Creating error panels')
            [input_error_panel, target_error_panel, global_step] = sess.run(
                [
                    input_error_panel_lt, target_error_panel_lt,
                    tf.train.get_global_step()
                ],
                feed_dict={
                    input_lt: stitched_input,
                    predict_input_lt: stitched_predict_input,
                    target_lt: stitched_target,
                    predict_target_lt: stitched_predict_target,
                })

            output_directory = os.path.join(output_directory,
                                            '%.8d' % global_step)
            if not gfile.Exists(output_directory):
                gfile.MakeDirs(output_directory)

            util.write_image(
                os.path.join(output_directory, 'input_error_panel.png'),
                input_error_panel[0, :, :, :])
            util.write_image(
                os.path.join(output_directory, 'target_error_panel.png'),
                target_error_panel[0, :, :, :])

            logging.info('Done generating images')
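
# For concreteness, the block-inference arithmetic used above works out as
# follows. A standalone sketch with assumed example values; none of these
# numbers come from a real run.
image_num_rows = 1000
extract_patch_size = 250
stitch_stride = 8
infer_size = 8

# Each inference block covers infer_size * stitch_stride = 64 output rows,
# and blocks are tiled until the input patch no longer fits in the image.
num_row_inferences = (image_num_rows -
                      extract_patch_size) // (stitch_stride * infer_size)
num_output_rows = num_row_inferences * infer_size * stitch_stride
print(num_row_inferences, num_output_rows)  # 11, 704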