Example #1
def bert_embed(input_ids, token_type_ids=None, position_ids=None, vocab_size=30522, embed_dim=768,
               num_pos_ids=512, dropout_prob=0.1, test=True):
    """Construct the embeddings from word, position and token type."""

    batch_size = input_ids.shape[0]
    seq_len = input_ids.shape[1]
    if position_ids is None:
        position_ids = F.arange(0, seq_len)
        position_ids = F.broadcast(F.reshape(
            position_ids, (1,)+position_ids.shape), (batch_size,) + position_ids.shape)
    if token_type_ids is None:
        token_type_ids = F.constant(val=0, shape=(batch_size, seq_len))

    embeddings = PF.embed(input_ids, vocab_size,
                          embed_dim, name='word_embeddings')
    position_embeddings = PF.embed(
        position_ids, num_pos_ids, embed_dim, name='position_embeddings')
    token_type_embeddings = PF.embed(
        token_type_ids, 2, embed_dim, name='token_type_embeddings')

    embeddings += position_embeddings
    embeddings += token_type_embeddings
    embeddings = PF.layer_normalization(
        embeddings, batch_axis=(0, 1), eps=1e-12, name='embed')

    if dropout_prob > 0.0 and not test:
        embeddings = F.dropout(embeddings, dropout_prob)

    return embeddings
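A minimal sketch of driving bert_embed above with dummy token ids; the imports, the batch/sequence sizes, and the random inputs below are illustrative assumptions, not part of the original example.

import numpy as np
import nnabla as nn

# Dummy token ids of shape (batch, seq_len); values must be < vocab_size.
ids = nn.Variable.from_numpy_array(
    np.random.randint(0, 30522, size=(2, 16)).astype(np.int32))

with nn.auto_forward():
    emb = bert_embed(ids, test=True)  # dropout is skipped when test=True

print(emb.shape)  # expected: (2, 16, 768)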
Example #2
def predict(x):
    with nn.auto_forward():
        x = x.reshape((1, sentence_length_source))
        enc_input = nn.Variable.from_numpy_array(x)
        mask = get_mask(enc_input)
        enc_input = time_distributed(PF.embed)(enc_input,
                                               vocab_size_source,
                                               embedding_size,
                                               name='enc_embeddings') * mask

        # encoder
        with nn.parameter_scope('encoder'):
            enc_output, c, h = lstm(enc_input,
                                    hidden,
                                    mask=mask,
                                    return_sequences=True,
                                    return_state=True)

        # decode
        pad = nn.Variable.from_numpy_array(np.array([w2i_target['<bos>']]))
        x = PF.embed(pad,
                     vocab_size_target,
                     embedding_size,
                     name='dec_embeddings')

        _cell, _hidden = c, h

        word_index = 0
        ret = []
        i = 0
        while i2w_target[word_index] != '。' and i < 20:
            with nn.parameter_scope('decoder'):
                with nn.parameter_scope('lstm'):
                    _cell, _hidden = lstm_cell(x, _cell, _hidden)
                    q = F.reshape(_hidden, (1, 1, hidden))
                    attention_output = global_attention(q,
                                                        enc_output,
                                                        mask=mask,
                                                        score='dot')
            attention_output = F.reshape(attention_output, (1, hidden))
            output = F.concatenate(_hidden, attention_output, axis=1)
            output = PF.affine(output, vocab_size_target, name='output')

            word_index = np.argmax(output.d[0])
            ret.append(word_index)
            x = nn.Variable.from_numpy_array(
                np.array([word_index], dtype=np.int32))
            x = PF.embed(x,
                         vocab_size_target,
                         embedding_size,
                         name='dec_embeddings')

            i += 1

        return ret
Example #3
def cnn(batch_size, vocab_size, text_len, classes, features=128, train=True):
    text = nn.Variable([batch_size, text_len])

    with nn.parameter_scope("text_embed"):
        embed = PF.embed(text, n_inputs=vocab_size, n_features=features)
    print("embed", embed.shape)

    embed = F.reshape(embed, (batch_size, 1, text_len, features))
    print("embed", embed.shape)

    combined = None
    for n in range(2, 6): # 2 - 5 gram
        with nn.parameter_scope(str(n) + "_gram"):
            with nn.parameter_scope("conv"):
                conv = PF.convolution(embed, 128, kernel=(n, features))
                conv = F.relu(conv)
            with nn.parameter_scope("pool"):
                pool = F.max_pooling(conv, kernel=(conv.shape[2], 1))
                if combined is None:
                    combined = F.identity(pool)
                else:
                    combined = F.concatenate(combined, pool)

    if train:
        combined = F.dropout(combined, 0.5)

    with nn.parameter_scope("output"):
        y = PF.affine(combined, classes)

    t = nn.Variable([batch_size, 1])

    _loss = F.softmax_cross_entropy(y, t)
    loss = F.reduce_mean(_loss)

    return text, y, loss, t
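A hedged sketch of one training step on the graph returned by cnn above; the solver choice, the hyper-parameter values, and the random dummy batch are assumptions for illustration only.

import numpy as np
import nnabla as nn
import nnabla.solvers as S

text, y, loss, t = cnn(batch_size=8, vocab_size=1000, text_len=50,
                       classes=4, features=128, train=True)

solver = S.Adam(1e-3)
solver.set_parameters(nn.get_parameters())

# One update on a random dummy batch.
text.d = np.random.randint(0, 1000, size=text.shape)
t.d = np.random.randint(0, 4, size=t.shape)
solver.zero_grad()
loss.forward(clear_no_need_grad=True)
loss.backward(clear_buffer=True)
solver.update()
print(loss.d)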
Example #4
def get_loss(l1,
             l2,
             x,
             t,
             w_init,
             b_init,
             num_words,
             batch_size,
             state_size,
             dropout=False,
             dropout_rate=0.5,
             embed_name='embed',
             pred_name='pred'):
    e_list = [
        PF.embed(x_elm, num_words, state_size, name=embed_name)
        for x_elm in F.split(x, axis=1)
    ]
    t_list = F.split(t, axis=1)
    loss = 0
    for i, (e_t, t_t) in enumerate(zip(e_list, t_list)):
        if dropout:
            h1 = l1(F.dropout(e_t, dropout_rate), w_init, b_init)
            h2 = l2(F.dropout(h1, dropout_rate), w_init, b_init)
            y = PF.affine(F.dropout(h2, dropout_rate),
                          num_words,
                          name=pred_name)
        else:
            h1 = l1(e_t, w_init, b_init)
            h2 = l2(h1, w_init, b_init)
            y = PF.affine(h2, num_words, name=pred_name)
        t_t = F.reshape(t_t, [batch_size, 1])
        loss += F.mean(F.softmax_cross_entropy(y, t_t))
    loss /= float(i + 1)

    return loss
Example #5
    def call(self, inputs):
        r"""
        Args:
            inputs (nn.Variable): An input variable of shape (B, T).

        Returns:
            nn.Variable: Output variable of shape (T, B, C).
        """
        # inputs of shape (B, T)
        hparams = self._hparams
        embedded_inputs = PF.embed(inputs,
                                   n_inputs=len(hparams.vocab),
                                   n_features=hparams.symbols_embedding_dim,
                                   initializer=NormalInitializer(0.3),
                                   name='embedding')  # (B, T, C)

        prenet_outputs = prenet(embedded_inputs,
                                layer_sizes=hparams.prenet_channels,
                                is_training=self.training,
                                scope='prenet_encoder')  # (B, T, C)

        encoder_outputs = encoder_cbhg(F.transpose(prenet_outputs, (0, 2, 1)),
                                       depth=hparams.encoder_embedding_dim,
                                       is_training=self.training)  # (T, B, C)

        return encoder_outputs
    def test_embeddings(self):
        x = nn.Variable((2, ))
        x.d = [0, 1]
        with nn.parameter_scope("embeddings"):
            output = embedding(x, 2, 3)
            output.forward()

        with nn.parameter_scope("embeddings"), nn.auto_forward():
            output_ref = PF.embed(x, 2, 3)
            self.assertTrue(np.allclose(output.d, output_ref.d))
Example #7
    def call(self, inputs):
        r"""Encoder layer.
        Args:
            inputs (nn.Variable): An input variable of shape (B, T) indicates indices
                of character embeddings.

        Returns:
            nn.Variable: Output variable of shape (T, B, C).
        """
        hp = self._hparams
        with nn.parameter_scope('embeddings'):
            val = np.sqrt(6.0 / (len(hp.vocab) + hp.symbols_embedding_dim))
            inputs = PF.embed(
                inputs,
                n_inputs=len(hp.vocab),
                n_features=hp.symbols_embedding_dim,
                initializer=UniformInitializer(lim=(-val,
                                                    val)))  # (B, T, C=512)

        with nn.parameter_scope('ngrams'):
            out = inputs
            for i in range(hp.encoder_n_convolutions):
                with nn.parameter_scope(f'filter_{i}'):
                    out = conv_norm(out,
                                    out_channels=hp.encoder_embedding_dim,
                                    kernel_size=hp.encoder_kernel_size,
                                    padding=(hp.encoder_kernel_size - 1) // 2,
                                    bias=False,
                                    stride=1,
                                    dilation=1,
                                    w_init_gain='relu',
                                    scope='conv_norm',
                                    channel_last=True)  # (B, T, C=512)
                    out = PF.batch_normalization(out,
                                                 batch_stat=self.training,
                                                 axes=[2])
                    out = F.relu(out)
                    if self.training:
                        # applied only during training; shape stays (B, T, C=512)
                        out = F.dropout(out, 0.5)

        with nn.parameter_scope('lstm_encoder'):
            out = F.transpose(out, (1, 0, 2))  # (B, T, C) -> (T, B, C)
            h = F.constant(shape=(2, 2, hp.batch_size,
                                  hp.encoder_embedding_dim // 2))
            c = F.constant(shape=(2, 2, hp.batch_size,
                                  hp.encoder_embedding_dim // 2))
            out, _, _ = PF.lstm(out,
                                h,
                                c,
                                training=self.training,
                                bidirectional=True)

        return out  # (T, B, C=512)
Example #8
    def test_embed_inverse(self):
        with nn.parameter_scope("embed_inverse"):
            x = nn.Variable((3, ))
            x.d[0] = 0
            x.d[1] = 1
            x.d[2] = 2
            with nn.auto_forward():
                embed = PF.embed(x, 3, 10)
                x = nn.Variable((3, 10))
                x2 = embed_inverse(embed, 3, 10)
            self.assertEqual(x2.shape, (3, 3))
def embedding(x, input_dim, output_dim, init=None, mask_zero=False):
    if init is None:
        init = I.UniformInitializer((-0.1, 0.1))
    initialized = "embed/W" in nn.get_parameters()
    result = PF.embed(x, input_dim, output_dim)
    if not initialized:
        nn.get_parameters()["embed/W"].d = init(
            nn.get_parameters()["embed/W"].shape)

    if mask_zero:
        return result, 1 - F.equal_scalar(x, 0)
    else:
        return result
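A small usage sketch (index values and shapes assumed) of the mask_zero=True branch of embedding above: index 0 is treated as padding, and the returned mask can be broadcast against the embeddings to zero out those rows.

import numpy as np
import nnabla as nn
import nnabla.functions as F

with nn.auto_forward():
    x = nn.Variable.from_numpy_array(np.array([0, 3, 5], dtype=np.int32))
    emb, mask = embedding(x, input_dim=10, output_dim=4, mask_zero=True)
    # mask is 0 for the pad index (0) and 1 elsewhere.
    emb = emb * F.reshape(mask, mask.shape + (1,))
print(emb.d)  # the first row is all zeros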
Example #10
def fast_text(batch_size, vocab_size, text_len, classes, features, train=True):
    text = nn.Variable([batch_size, text_len])

    with nn.parameter_scope("text_embed"):
        embed = PF.embed(text, n_inputs=vocab_size, n_features=features)

    avg = F.mean(embed, axis=1)

    with nn.parameter_scope("output"):
        y = PF.affine(avg, classes)

    t = nn.Variable([batch_size, 1])

    _loss = F.softmax_cross_entropy(y, t)
    loss = F.reduce_mean(_loss)

    return text, y, loss, t
Example #11
def build_model(train=True, get_embeddings=False):
    x = nn.Variable((batch_size, sentence_length, ptb_dataset.word_length))
    mask = expand_dims(F.sign(x), axis=-1)
    t = nn.Variable((batch_size, sentence_length))

    with nn.parameter_scope('char_embedding'):
        h = PF.embed(x, char_vocab_size, char_embedding_dim) * mask
    h = F.transpose(h, (0, 3, 1, 2))
    output = []
    for f, f_size in zip(filters, filster_sizes):
        _h = PF.convolution(h, f, kernel=(1, f_size), pad=(0, f_size//2), name='conv_{}'.format(f_size))
        _h = F.max_pooling(_h, kernel=(1, ptb_dataset.word_length))
        output.append(_h)
    h = F.concatenate(*output, axis=1)
    h = F.transpose(h, (0, 2, 1, 3))

    mask = get_mask(F.sum(x, axis=2))
    embeddings = F.reshape(h, (batch_size, sentence_length, sum(filters))) * mask

    if get_embeddings:
        return x, embeddings

    with nn.parameter_scope('highway1'):
        h = time_distributed(highway)(embeddings)
    with nn.parameter_scope('highway2'):
        h = time_distributed(highway)(h)
    with nn.parameter_scope('lstm1'):
        h = lstm(h, lstm_size, mask=mask, return_sequences=True)
    with nn.parameter_scope('lstm2'):
        h = lstm(h, lstm_size, mask=mask, return_sequences=True)
    with nn.parameter_scope('hidden'):
        h = F.relu(time_distributed(PF.affine)(h, lstm_size))
    if train:
        h = F.dropout(h, p=dropout_ratio)
    with nn.parameter_scope('output'):
        y = time_distributed(PF.affine)(h, word_vocab_size)

    mask = F.sign(t) # do not predict 'pad'.
    entropy = time_distributed_softmax_cross_entropy(y, expand_dims(t, axis=-1)) * mask
    count = F.sum(mask, axis=1)
    loss = F.mean(F.div2(F.sum(entropy, axis=1), count))
    return x, t, loss
Example #12
            self.params[key].data = projection(self.params[key].data,
                                               eps=self.eps)


def loss_function(u, v, negative_samples):
    return F.sum(-F.log(
        F.exp(-distance(u, v)) / sum([
            F.exp(-distance(u, x)) for x in F.split(negative_samples, axis=2)
        ])))


u = nn.Variable((batch_size, ))
v = nn.Variable((batch_size, ))
negative_samples = nn.Variable((batch_size, negative_sample_size))

_u = PF.embed(u, vocab_size, embedding_size)
_v = PF.embed(v, vocab_size, embedding_size)
_neg = PF.embed(negative_samples, vocab_size, embedding_size)
_neg = F.transpose(_neg, axes=(0, 2, 1))

loss = loss_function(_u, _v, _neg)

nn.get_parameters()["embed/W"].d = I.UniformInitializer(
    [-0.01, 0.01])(shape=(vocab_size, embedding_size))

solver = RiemannianSgd(lr=0.1)
solver.set_parameters(nn.get_parameters())

trainer = Trainer(inputs=[u, v, negative_samples], loss=loss, solver=solver)
trainer.run(train_data_iter, None, epochs=max_epoch)

train_data_iter = data_iterator_simple(load_train_func,
                                       len(x_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(x_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

x = nn.Variable([batch_size, window_size * 2])
with nn.parameter_scope('W_in'):
    h = PF.embed(x, vocab_size, embedding_size)
h = F.mean(h, axis=1)
h = expand_dims(h, axis=-1)  # (batch_size, embedding_size, 1)
t = nn.Variable([batch_size, 1])
t_neg = nn.Variable([batch_size, k])
with nn.parameter_scope('W_out'):
    _t = PF.embed(t, vocab_size,
                  embedding_size)  # (batch_size, 1, embedding_size)
    _t_neg = PF.embed(t_neg, vocab_size,
                      embedding_size)  # (batch_size, k, embedding_size)

t_score = F.sigmoid(F.reshape(F.batch_matmul(_t, h), shape=(batch_size, 1)))
t_neg_score = F.sigmoid(
    F.reshape(F.batch_matmul(_t_neg, h), shape=(batch_size, k)))

t_loss = F.binary_cross_entropy(t_score, F.constant(1, shape=(batch_size, 1)))
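The snippet stops at the positive-sample term; a hedged completion of the negative-sampling objective, written by analogy with t_loss and not shown in the original, could look like this.

# Negative samples are pushed toward a score of 0; both terms are combined.
t_neg_loss = F.binary_cross_entropy(
    t_neg_score, F.constant(0, shape=(batch_size, k)))
loss = F.mean(t_loss + F.sum(t_neg_loss, axis=1, keepdims=True))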
Example #14
def train():
    args = get_args()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in {}:{}".format(args.context, args.type_config))
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir)
    _data_source = data_iterator._data_source  # dirty hack...

    # model
    x = nn.Variable(
        shape=(args.batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    if args.use_speaker_id:
        s_id = nn.Variable(shape=(args.batch_size, 1))
        with nn.parameter_scope("speaker_embedding"):
            s_emb = PF.embed(s_id, n_inputs=_data_source.n_speaker,
                             n_features=WavenetConfig.speaker_dims)
            s_emb = F.transpose(s_emb, (0, 2, 1))
    else:
        s_emb = None

    net = WaveNet()
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))

    # (B, T, 1)
    t = nn.Variable(shape=(args.batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # for generation
    prob = F.softmax(pred)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)

    # setup save env.
    audio_save_path = os.path.join(os.path.abspath(
        args.model_save_path), "audio_results")
    if audio_save_path and not os.path.exists(audio_save_path):
        os.makedirs(audio_save_path)

    # Training loop.
    for i in range(args.max_iter):
        # todo: validation

        x.d, _speaker, t.d = data_iterator.next()
        if args.use_speaker_id:
            s_id.d = _speaker.reshape(-1, 1)

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.update()

        loss.data.cast(np.float32, ctx)
        monitor_loss.add(i, loss.d.copy())

        if i % args.model_save_interval == 0:
            prob.forward()
            audios = mu_law_decode(
                np.argmax(prob.d, axis=-1), quantize=data_config.q_bit_len)  # (B, T)
            save_audio(audios, i, audio_save_path)
Example #15
def train_nerf(config, comm, model, dataset='blender'):

    use_transient = False
    use_embedding = False

    if model == 'wild':
        use_transient = True
        use_embedding = True
    elif model == 'uncertainty':
        use_transient = True
    elif model == 'appearance':
        use_embedding = True

    save_results_dir = config.log.save_results_dir
    os.makedirs(save_results_dir, exist_ok=True)

    train_loss_dict = {
        'train_coarse_loss': 0.0,
        'train_fine_loss': 0.0,
        'train_total_loss': 0.0,
    }

    test_metric_dict = {'test_loss': 0.0, 'test_psnr': 0.0}

    monitor_manager = MonitorManager(train_loss_dict, test_metric_dict,
                                     save_results_dir)

    if dataset != 'phototourism':
        images, poses, _, hwf, i_test, i_train, near_plane, far_plane = get_data(
            config)
        height, width, focal_length = hwf
    else:
        di = get_photo_tourism_dataiterator(config, 'train', comm)
        val_di = get_photo_tourism_dataiterator(config, 'val', comm)

    if model != 'vanilla':
        if dataset != 'phototourism':
            config.train.n_vocab = max(np.max(i_train), np.max(i_test)) + 1
        print(
            f'Setting Vocabulary size of embedding as {config.train.n_vocab}')

    if dataset != 'phototourism':
        if model in ['vanilla']:
            if comm is not None:
                # test on fewer images (restrict i_test per rank)
                i_test = i_test[3 * comm.rank:3 * (comm.rank + 1)]
                pass
            else:
                # test on fewer images
                i_test = i_test[:3]
                pass
        else:
            # i_test = i_train[0:5]
            i_test = [i * (comm.rank + 1) for i in range(5)]
    else:
        i_test = [1]

    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    lr = config.solver.lr
    num_decay_steps = config.solver.lr_decay_step * 1000
    lr_decay_factor = config.solver.lr_decay_factor
    solver = S.Adam(alpha=lr)

    load_solver_state = False
    if config.checkpoint.param_path is not None:
        nn.load_parameters(config.checkpoint.param_path)
        load_solver_state = True

    if comm is not None:
        num_decay_steps /= comm.n_procs
        comm_size = comm.n_procs
    else:
        comm_size = 1
    pbar = trange(config.train.num_iterations // comm_size,
                  disable=(comm is not None and comm.rank > 0))

    for i in pbar:

        if dataset != 'phototourism':

            idx = np.random.choice(i_train)
            image = nn.Variable.from_numpy_array(images[idx][None, :, :, :3])
            pose = nn.Variable.from_numpy_array(poses[idx])

            ray_directions, ray_origins = get_ray_bundle(
                height, width, focal_length, pose)

            grid = get_direction_grid(width,
                                      height,
                                      focal_length,
                                      return_ij_2d_grid=True)
            grid = F.reshape(grid, (-1, 2))

            select_inds = np.random.choice(grid.shape[0],
                                           size=[config.train.num_rand_points],
                                           replace=False)
            select_inds = F.gather_nd(grid, select_inds[None, :])
            select_inds = F.transpose(select_inds, (1, 0))

            embed_inp = nn.Variable.from_numpy_array(
                np.full((config.train.chunksize_fine, ), idx, dtype=int))

            ray_origins = F.gather_nd(ray_origins, select_inds)
            ray_directions = F.gather_nd(ray_directions, select_inds)

            image = F.gather_nd(image[0], select_inds)

        else:
            rays, embed_inp, image = di.next()
            ray_origins = nn.Variable.from_numpy_array(rays[:, :3])
            ray_directions = nn.Variable.from_numpy_array(rays[:, 3:6])
            near_plane = nn.Variable.from_numpy_array(rays[:, 6])
            far_plane = nn.Variable.from_numpy_array(rays[:, 7])

            embed_inp = nn.Variable.from_numpy_array(embed_inp)
            image = nn.Variable.from_numpy_array(image)

            hwf = None

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

        if use_transient:
            rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf,
                image=image)
            course_loss = 0.5 * F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = 0.5 * F.mean(
                F.squared_error(rgb_map_fine, image) /
                F.reshape(F.pow_scalar(beta, 2), beta.shape + (1, )))
            beta_reg_loss = 3 + F.mean(F.log(beta))
            sigma_trans_reg_loss = 0.01 * F.mean(transient_sigma)
            loss = course_loss + fine_loss + beta_reg_loss + sigma_trans_reg_loss
        else:
            rgb_map_course, _, _, _, rgb_map_fine, _, _, _ = forward_pass(
                ray_directions,
                ray_origins,
                near_plane,
                far_plane,
                app_emb,
                trans_emb,
                encode_position_function,
                encode_direction_function,
                config,
                use_transient,
                hwf=hwf)
            course_loss = F.mean(F.squared_error(rgb_map_course, image))
            fine_loss = F.mean(F.squared_error(rgb_map_fine, image))
            loss = course_loss + fine_loss

        pbar.set_description(
            f'Total: {np.around(loss.d, 4)}, Course: {np.around(course_loss.d, 4)}, Fine: {np.around(fine_loss.d, 4)}'
        )

        solver.set_parameters(nn.get_parameters(),
                              reset=False,
                              retain_state=True)
        if load_solver_state:
            solver.load_states(config['checkpoint']['solver_path'])
            load_solver_state = False

        solver.zero_grad()

        loss.backward(clear_buffer=True)

        # Exponential LR decay
        if dataset != 'phototourism':
            lr_factor = (lr_decay_factor**((i) / num_decay_steps))
            solver.set_learning_rate(lr * lr_factor)
        else:
            if i % num_decay_steps == 0 and i != 0:
                solver.set_learning_rate(lr * lr_decay_factor)

        if comm is not None:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=True)
        solver.update()

        if ((i % config.train.save_interval == 0
             or i == config.train.num_iterations - 1)
                and i != 0) and (comm is not None and comm.rank == 0):
            nn.save_parameters(os.path.join(save_results_dir, f'iter_{i}.h5'))
            solver.save_states(
                os.path.join(save_results_dir, f'solver_iter_{i}.h5'))

        if (i % config.train.test_interval == 0
                or i == config.train.num_iterations - 1) and i != 0:
            avg_psnr, avg_mse = 0.0, 0.0
            for i_t in trange(len(i_test)):

                if dataset != 'phototourism':
                    idx_test = i_test[i_t]
                    image = nn.NdArray.from_numpy_array(
                        images[idx_test][None, :, :, :3])
                    pose = nn.NdArray.from_numpy_array(poses[idx_test])

                    ray_directions, ray_origins = get_ray_bundle(
                        height, width, focal_length, pose)

                    ray_directions = F.reshape(ray_directions, (-1, 3))
                    ray_origins = F.reshape(ray_origins, (-1, 3))

                    embed_inp = nn.NdArray.from_numpy_array(
                        np.full((config.train.chunksize_fine, ),
                                idx_test,
                                dtype=int))

                else:
                    rays, embed_inp, image = val_di.next()
                    ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
                    ray_directions = nn.NdArray.from_numpy_array(rays[0, :,
                                                                      3:6])
                    near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
                    far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

                    embed_inp = nn.NdArray.from_numpy_array(
                        embed_inp[0, :config.train.chunksize_fine])
                    image = nn.NdArray.from_numpy_array(image[0].transpose(
                        1, 2, 0))
                    image = F.reshape(image, (1, ) + image.shape)
                    idx_test = 1

                app_emb, trans_emb = None, None
                if use_embedding:
                    with nn.parameter_scope('embedding_a'):
                        app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                           config.train.n_app)

                if use_transient:
                    with nn.parameter_scope('embedding_t'):
                        trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                             config.train.n_trans)

                num_ray_batches = ray_directions.shape[
                    0] // config.train.ray_batch_size + 1

                if use_transient:
                    rgb_map_fine_list, static_rgb_map_fine_list, transient_rgb_map_fine_list = [], [], []
                else:
                    rgb_map_fine_list, depth_map_fine_list = [], []

                for r_idx in trange(num_ray_batches):
                    if r_idx != num_ray_batches - 1:
                        ray_d, ray_o = ray_directions[
                            r_idx * config.train.ray_batch_size:(r_idx + 1) *
                            config.train.ray_batch_size], ray_origins[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]

                        if dataset == 'phototourism':
                            near_plane = near_plane_[
                                r_idx *
                                config.train.ray_batch_size:(r_idx + 1) *
                                config.train.ray_batch_size]
                            far_plane = far_plane_[r_idx *
                                                   config.train.ray_batch_size:
                                                   (r_idx + 1) *
                                                   config.train.ray_batch_size]

                    else:
                        if ray_directions.shape[0] - (
                                num_ray_batches -
                                1) * config.train.ray_batch_size == 0:
                            break
                        ray_d, ray_o = ray_directions[
                            r_idx *
                            config.train.ray_batch_size:, :], ray_origins[
                                r_idx * config.train.ray_batch_size:, :]
                        if dataset == 'phototourism':
                            near_plane = near_plane_[r_idx * config.train.
                                                     ray_batch_size:]
                            far_plane = far_plane_[r_idx * config.train.
                                                   ray_batch_size:]

                    if use_transient:
                        rgb_map_course, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, beta, static_sigma, transient_sigma = forward_pass(
                            ray_d,
                            ray_o,
                            near_plane,
                            far_plane,
                            app_emb,
                            trans_emb,
                            encode_position_function,
                            encode_direction_function,
                            config,
                            use_transient,
                            hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        static_rgb_map_fine_list.append(static_rgb_map_fine)
                        transient_rgb_map_fine_list.append(
                            transient_rgb_map_fine)

                    else:
                        _, _, _, _, rgb_map_fine, depth_map_fine, _, _ = \
                            forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                         encode_position_function, encode_direction_function, config, use_transient, hwf=hwf)

                        rgb_map_fine_list.append(rgb_map_fine)
                        depth_map_fine_list.append(depth_map_fine)

                if use_transient:
                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    static_rgb_map_fine = F.concatenate(
                        *static_rgb_map_fine_list, axis=0)
                    transient_rgb_map_fine = F.concatenate(
                        *transient_rgb_map_fine_list, axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                                    image[0].shape)
                    transient_rgb_map_fine = F.reshape(transient_rgb_map_fine,
                                                       image[0].shape)
                    static_trans_img_to_save = np.concatenate(
                        (static_rgb_map_fine.data,
                         np.ones((image[0].shape[0], 5, 3)),
                         transient_rgb_map_fine.data),
                        axis=1)
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)
                else:

                    rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
                    depth_map_fine = F.concatenate(*depth_map_fine_list,
                                                   axis=0)

                    rgb_map_fine = F.reshape(rgb_map_fine, image[0].shape)
                    depth_map_fine = F.reshape(depth_map_fine,
                                               image[0].shape[:-1])
                    img_to_save = np.concatenate(
                        (image[0].data, np.ones(
                            (image[0].shape[0], 5, 3)), rgb_map_fine.data),
                        axis=1)

                filename = os.path.join(save_results_dir,
                                        f'{i}_{idx_test}.png')
                try:
                    imsave(filename,
                           np.clip(img_to_save, 0, 1),
                           channel_first=False)
                    print(f'Saved generation at {filename}')
                    if use_transient:
                        filename_static_trans = os.path.join(
                            save_results_dir, f's_t_{i}_{idx_test}.png')
                        imsave(filename_static_trans,
                               np.clip(static_trans_img_to_save, 0, 1),
                               channel_first=False)

                    else:
                        filename_dm = os.path.join(save_results_dir,
                                                   f'dm_{i}_{idx_test}.png')
                        depth_map_fine = (depth_map_fine.data -
                                          depth_map_fine.data.min()) / (
                                              depth_map_fine.data.max() -
                                              depth_map_fine.data.min())
                        imsave(filename_dm,
                               depth_map_fine[:, :, None],
                               channel_first=False)
                        plt.imshow(depth_map_fine.data)
                        plt.savefig(filename_dm)
                        plt.close()
                except:
                    pass

                avg_mse += F.mean(F.squared_error(rgb_map_fine, image[0])).data
                avg_psnr += (-10. * np.log10(
                    F.mean(F.squared_error(rgb_map_fine, image[0])).data))

            test_metric_dict['test_loss'] = avg_mse / len(i_test)
            test_metric_dict['test_psnr'] = avg_psnr / len(i_test)
            monitor_manager.add(i, test_metric_dict)
            print(
                f'Saved generations after {i} training iterations! Average PSNR: {avg_psnr/len(i_test)}. Average MSE: {avg_mse/len(i_test)}'
            )
Example #16
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--output-filename',
                        '-o',
                        type=str,
                        default='video.gif',
                        help="name of an output file.")
    parser.add_argument('--output-static-filename',
                        '-os',
                        type=str,
                        default='video_static.gif',
                        help="name of an output file.")
    parser.add_argument('--config-path',
                        '-c',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='model and training configuration file')
    parser.add_argument('--weight-path',
                        '-w',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='path to pretrained NeRF parameters')
    parser.add_argument(
        '--model',
        type=str,
        choices=['wild', 'uncertainty', 'appearance', 'vanilla'],
        required=True,
        help='Select the model to train')

    parser.add_argument('--visualization-type',
                        '-v',
                        type=str,
                        choices=['zoom', '360-rotation', 'default'],
                        default='default-render-poses',
                        help='type of visualization')

    parser.add_argument(
        '--downscale',
        '-d',
        default=1,
        type=float,
        help="downsampling factor for the rendered images for faster inference"
    )

    parser.add_argument(
        '--num-images',
        '-n',
        default=120,
        type=int,
        help="Number of images to generate for the output video/gif")

    parser.add_argument("--fast",
                        help="Use Fast NeRF architecture",
                        action="store_true")

    args = parser.parse_args()

    use_transient = False
    use_embedding = False

    if args.model == 'wild':
        use_transient = True
        use_embedding = True
    elif args.model == 'uncertainty':
        use_transient = True
    elif args.model == 'appearance':
        use_embedding = True

    args = parser.parse_args()
    config = read_yaml(args.config_path)

    config.data.downscale = args.downscale

    nn.set_auto_forward(True)
    ctx = get_extension_context('cuda')
    nn.set_default_context(ctx)
    nn.load_parameters(args.weight_path)

    _, _, render_poses, hwf, _, _, near_plane, far_plane = get_data(config)
    height, width, focal_length = hwf
    print(
        f'Rendering with Height {height}, Width {width}, Focal Length: {focal_length}'
    )

    # mapping_net = MLP
    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    frames = []
    if use_transient:
        static_frames = []

    if args.visualization_type == '360-rotation':
        print('The 360 degree rotation result will not work with LLFF data!')
        pbar = tqdm(np.linspace(0, 360, args.num_images, endpoint=False))
    elif args.visualization_type == 'zoom':
        pbar = tqdm(
            np.linspace(near_plane, far_plane, args.num_images,
                        endpoint=False))
    else:
        args.num_images = min(args.num_images, render_poses.shape[0])
        pbar = tqdm(
            np.arange(0, render_poses.shape[0],
                      render_poses.shape[0] // args.num_images))

    print(f'Rendering {args.num_images} poses...')

    for th in pbar:

        if args.visualization_type == '360-rotation':
            pose = nn.NdArray.from_numpy_array(pose_spherical(th, -30., 4.))
        elif args.visualization_type == 'zoom':
            pose = nn.NdArray.from_numpy_array(trans_t(th))
        else:
            pose = nn.NdArray.from_numpy_array(render_poses[th][:3, :4])
            # pose = nn.NdArray.from_numpy_array(render_poses[0][:3, :4])

        ray_directions, ray_origins = get_ray_bundle(height, width,
                                                     focal_length, pose)

        ray_directions = F.reshape(ray_directions, (-1, 3))
        ray_origins = F.reshape(ray_origins, (-1, 3))

        num_ray_batches = ray_directions.shape[
            0] // config.train.ray_batch_size + 1

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                embed_inp = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), 1, dtype=int))
                app_emb = PF.embed(embed_inp, config.train.n_vocab,
                                   config.train.n_app)

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                embed_inp = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), th, dtype=int))
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

            static_rgb_map_fine_list, transient_rgb_map_fine_list = [], []

        rgb_map_fine_list = []

        for i in trange(num_ray_batches):
            if i != num_ray_batches - 1:
                ray_d, ray_o = ray_directions[i * config.train.ray_batch_size:(
                    i + 1) * config.train.ray_batch_size], ray_origins[
                        i * config.train.ray_batch_size:(i + 1) *
                        config.train.ray_batch_size]
            else:
                ray_d, ray_o = ray_directions[
                    i * config.train.ray_batch_size:, :], ray_origins[
                        i * config.train.ray_batch_size:, :]

            if use_transient:
                _, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, _, _, _ = forward_pass(
                    ray_d,
                    ray_o,
                    near_plane,
                    far_plane,
                    app_emb,
                    trans_emb,
                    encode_position_function,
                    encode_direction_function,
                    config,
                    use_transient,
                    hwf=hwf,
                    fast=args.fast)

                static_rgb_map_fine_list.append(static_rgb_map_fine)
                transient_rgb_map_fine_list.append(transient_rgb_map_fine)

            else:
                _, _, _, _, rgb_map_fine, _, _, _ = \
                    forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb, encode_position_function,
                                 encode_direction_function, config, use_transient, hwf=hwf, fast=args.fast)
            rgb_map_fine_list.append(rgb_map_fine)

        rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
        rgb_map_fine = F.reshape(rgb_map_fine, (height, width, 3))

        if use_transient:
            static_rgb_map_fine = F.concatenate(*static_rgb_map_fine_list,
                                                axis=0)
            static_rgb_map_fine = F.reshape(static_rgb_map_fine,
                                            (height, width, 3))

        frames.append(
            (255 * np.clip(rgb_map_fine.data, 0, 1)).astype(np.uint8))
        if use_transient:
            static_frames.append(
                (255 * np.clip(static_rgb_map_fine.data, 0, 1)).astype(
                    np.uint8))

    imageio.mimwrite(args.output_filename, frames, fps=30)
    if use_transient:
        imageio.mimwrite(args.output_static_filename, static_frames, fps=30)
Example #17
def LSTMAttentionDecoder(inputs=None,
                         encoder_output=None,
                         initial_state=None,
                         return_sequences=False,
                         return_state=False,
                         inference_params=None,
                         name='lstm'):

    if inputs is None:
        assert inference_params is not None, 'if inputs is None, inference_params must not be None.'
    else:
        sentence_length = inputs.shape[1]

    assert type(initial_state) is tuple or type(initial_state) is list, \
           'initial_state must be a tuple or a list.'
    assert len(initial_state) == 2, \
           'initial_state must have only two states.'

    c0, h0 = initial_state

    assert c0.shape == h0.shape, 'shapes of initial_state must be the same.'
    batch_size, units = c0.shape

    cell = c0
    hidden = h0

    hs = []

    if inference_params is None:
        xs = F.split(F.slice(inputs,
                             stop=(batch_size, sentence_length - 1, units)),
                     axis=1)
        pad = nn.Variable.from_numpy_array(
            np.array([w2i_source['pad']] * batch_size))
        xs = [
            PF.embed(
                pad, vocab_size_source, embedding_size, name='enc_embeddings')
        ] + list(xs)

        compute_context = GlobalAttention(encoder_output, 1024)

        for x in xs:
            with nn.parameter_scope(name):
                cell, hidden = lstm_cell(x, cell, hidden)
                context = compute_context(hidden)
                h_t = F.tanh(
                    PF.affine(F.concatenate(context, hidden, axis=1),
                              1024,
                              with_bias=False,
                              name='Wc'))
            hs.append(h_t)
    else:
        assert batch_size == 1, 'batch size of inference mode must be 1.'
        embed_weight, output_weight, output_bias = inference_params
        pad = nn.Variable.from_numpy_array(
            np.array([w2i_source['pad']] * batch_size))
        x = PF.embed(pad,
                     vocab_size_source,
                     embedding_size,
                     name='enc_embeddings')

        compute_context = GlobalAttention(encoder_output, 1024)

        word_index = 0
        ret = []
        i = 0
        while i2w_target[word_index] != '。' and i < 20:
            with nn.parameter_scope(name):
                cell, hidden = lstm_cell(x, cell, hidden)
                context = compute_context(hidden)
                h_t = F.tanh(
                    PF.affine(F.concatenate(context, hidden, axis=1),
                              1024,
                              with_bias=False,
                              name='Wc'))
            output = F.affine(h_t, output_weight, bias=output_bias)
            word_index = np.argmax(output.d[0])
            ret.append(word_index)
            x = nn.Variable.from_numpy_array(
                np.array([word_index], dtype=np.int32))
            x = F.embed(x, embed_weight)

            i += 1
        return ret

    if return_sequences:
        ret = F.stack(*hs, axis=1)
    else:
        ret = hs[-1]

    if return_state:
        return ret, cell, hidden
    else:
        return ret
Example #18
def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--output-filename',
                        '-o',
                        type=str,
                        default='video.gif',
                        help="name of an output file.")
    parser.add_argument('--output-static-filename',
                        '-os',
                        type=str,
                        default='video_static.gif',
                        help="name of an output file.")
    parser.add_argument('--config-path',
                        '-c',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='model and training configuration file')
    parser.add_argument('--weight-path',
                        '-w',
                        type=str,
                        default='configs/llff.yaml',
                        required=True,
                        help='path to pretrained NeRF parameters')
    parser.add_argument(
        '--model',
        type=str,
        choices=['wild', 'uncertainty', 'appearance', 'vanilla'],
        required=True,
        help='Select the model to train')

    parser.add_argument('--visualization-type',
                        '-v',
                        type=str,
                        choices=['zoom', '360-rotation', 'default'],
                        default='default-render-poses',
                        help='type of visualization')

    parser.add_argument(
        '--downscale',
        '-d',
        default=1,
        type=float,
        help="downsampling factor for the rendered images for faster inference"
    )

    parser.add_argument(
        '--num-images',
        '-n',
        default=120,
        type=int,
        help="Number of images to generate for the output video/gif")

    args = parser.parse_args()

    nn.set_auto_forward(True)
    comm = init_nnabla(ext_name="cudnn")

    use_transient = False
    use_embedding = False

    if args.model == 'wild':
        use_transient = True
        use_embedding = True
    elif args.model == 'uncertainty':
        use_transient = True
    elif args.model == 'appearance':
        use_embedding = True

    args = parser.parse_args()
    config = read_yaml(args.config_path)

    config.data.downscale = args.downscale
    nn.load_parameters(args.weight_path)

    data_source = get_photo_tourism_dataiterator(config, 'test', comm)

    # Pose, Appearance index for generating novel views
    # as well as camera trajectory is hard-coded here.
    data_source.test_appearance_idx = 125
    pose_idx = 125
    dx = np.linspace(-0.2, 0.15, args.num_images // 3)
    dy = -0.15
    dz = np.linspace(0.1, 0.22, args.num_images // 3)

    embed_idx_list = list(data_source.poses_dict.keys())

    data_source.poses_test = np.tile(data_source.poses_dict[pose_idx],
                                     (args.num_images, 1, 1))
    for i in range(0, args.num_images // 3):
        data_source.poses_test[i, 0, 3] += dx[i]
        data_source.poses_test[i, 1, 3] += dy
    for i in range(args.num_images // 3, args.num_images // 2):
        data_source.poses_test[i, 0, 3] += dx[len(dx) - 1 - i]
        data_source.poses_test[i, 1, 3] += dy

    for i in range(args.num_images // 2, 5 * args.num_images // 6):
        data_source.poses_test[i, 2, 3] += dz[i - args.num_images // 2]
        data_source.poses_test[i, 1, 3] += dy
        data_source.poses_test[i, 0, 3] += dx[len(dx) // 2]

    for i in range(5 * args.num_images // 6, args.num_images):
        data_source.poses_test[i, 2, 3] += dz[args.num_images - 1 - i]
        data_source.poses_test[i, 1, 3] += dy
        data_source.poses_test[i, 0, 3] += dx[len(dx) // 2]

    # mapping_net = MLP
    encode_position_function = get_encoding_function(
        config.train.num_encodings_position, True, True)
    if config.train.use_view_directions:
        encode_direction_function = get_encoding_function(
            config.train.num_encodings_direction, True, True)
    else:
        encode_direction_function = None

    frames = []
    if use_transient:
        static_frames = []

    pbar = tqdm(np.arange(0, data_source.poses_test.shape[0]))
    data_source._size = data_source.poses_test.shape[0]
    data_source.test_img_w = 400
    data_source.test_img_h = 400
    data_source.test_focal = data_source.test_img_w / 2 / np.tan(np.pi / 6)
    data_source.test_K = np.array(
        [[data_source.test_focal, 0, data_source.test_img_w / 2],
         [0, data_source.test_focal, data_source.test_img_h / 2], [0, 0, 1]])

    data_source._indexes = np.arange(0, data_source._size)

    di = data_iterator(data_source, batch_size=1)

    print(f'Rendering {args.num_images} poses...')

    a = [1, 128]  # appearance-embedding indices to interpolate between
    alpha = np.linspace(0, 1, args.num_images)

    for th in pbar:

        rays, embed_inp = di.next()
        ray_origins = nn.NdArray.from_numpy_array(rays[0, :, :3])
        ray_directions = nn.NdArray.from_numpy_array(rays[0, :, 3:6])
        near_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 6])
        far_plane_ = nn.NdArray.from_numpy_array(rays[0, :, 7])

        embed_inp = nn.NdArray.from_numpy_array(
            embed_inp[0, :config.train.chunksize_fine])
        image_shape = (data_source.test_img_w, data_source.test_img_h, 3)

        ray_directions = F.reshape(ray_directions, (-1, 3))
        ray_origins = F.reshape(ray_origins, (-1, 3))

        num_ray_batches = (ray_directions.shape[0] +
                           config.train.ray_batch_size -
                           1) // config.train.ray_batch_size

        app_emb, trans_emb = None, None
        if use_embedding:
            with nn.parameter_scope('embedding_a'):
                embed_inp_app = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), a[0], dtype=int))
                app_emb = PF.embed(embed_inp_app, config.train.n_vocab,
                                   config.train.n_app)

                embed_inp_app = nn.NdArray.from_numpy_array(
                    np.full((config.train.chunksize_fine, ), a[1], dtype=int))
                app_emb_2 = PF.embed(embed_inp_app, config.train.n_vocab,
                                     config.train.n_app)

                app_emb = app_emb * alpha[th] + app_emb_2 * (1 - alpha[th])

        if use_transient:
            with nn.parameter_scope('embedding_t'):
                trans_emb = PF.embed(embed_inp, config.train.n_vocab,
                                     config.train.n_trans)

            static_rgb_map_fine_list, transient_rgb_map_fine_list = [], []

        rgb_map_fine_list = []

        for i in trange(num_ray_batches):
            ray_d, ray_o = ray_directions[i * config.train.ray_batch_size:(
                i + 1) * config.train.ray_batch_size], ray_origins[
                    i * config.train.ray_batch_size:(i + 1) *
                    config.train.ray_batch_size]

            near_plane = near_plane_[i * config.train.ray_batch_size:(i + 1) *
                                     config.train.ray_batch_size]
            far_plane = far_plane_[i * config.train.ray_batch_size:(i + 1) *
                                   config.train.ray_batch_size]

            if use_transient:
                _, rgb_map_fine, static_rgb_map_fine, transient_rgb_map_fine, _, _, _ = forward_pass(
                    ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                    encode_position_function, encode_direction_function,
                    config, use_transient)

                static_rgb_map_fine_list.append(static_rgb_map_fine)
                transient_rgb_map_fine_list.append(transient_rgb_map_fine)

            else:
                _, _, _, _, rgb_map_fine, _, _, _ = \
                    forward_pass(ray_d, ray_o, near_plane, far_plane, app_emb, trans_emb,
                                 encode_position_function, encode_direction_function, config, use_transient)

            rgb_map_fine_list.append(rgb_map_fine)

        rgb_map_fine = F.concatenate(*rgb_map_fine_list, axis=0)
        rgb_map_fine = F.reshape(rgb_map_fine, image_shape)

        if use_transient:
            static_rgb_map_fine = F.concatenate(*static_rgb_map_fine_list,
                                                axis=0)
            static_rgb_map_fine = F.reshape(static_rgb_map_fine, image_shape)

        frames.append(
            (255 * np.clip(rgb_map_fine.data, 0, 1)).astype(np.uint8))
        if use_transient:
            static_frames.append(
                (255 * np.clip(static_rgb_map_fine.data, 0, 1)).astype(
                    np.uint8))

    imageio.mimwrite(args.output_filename, frames, fps=30)
    imageio.mimwrite(args.output_static_filename, static_frames, fps=30)
Example #19
def main():

    # Get arguments
    args = get_args()
    data_file = "https://raw.githubusercontent.com/tomsercu/lstm/master/data/ptb.train.txt"
    model_file = args.work_dir + "model.h5"

    # Load Dataset
    itow, wtoi, dataset = load_ptbset(data_file)

    # Computation environment settings
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create data provider
    n_word = len(wtoi)
    n_dim = args.embed_dim
    batchsize = args.batchsize
    half_window = args.half_window_length
    n_negative = args.n_negative_sample

    di = DataIteratorForEmbeddingLearning(
        batchsize=batchsize,
        half_window=half_window,
        n_negative=n_negative,
        dataset=dataset)

    # Create model
    # - Real batch size including context samples and negative samples
    size = batchsize * (1 + n_negative) * (2 * (half_window - 1))

    # Model for learning
    # - input variables
    xl = nn.Variable((size,))  # variable for word
    yl = nn.Variable((size,))  # variable for context

    # Embed layers for word embedding function
    # - f_embed: maps a word index x to y, an n_dim feature vector,
    # --  for each sample in a minibatch
    hx = PF.embed(xl, n_word, n_dim, name="e1")  # feature vector for word
    hy = PF.embed(yl, n_word, n_dim, name="e1")  # feature vector for context
    hl = F.sum(hx * hy, axis=1)

    # -- Approximated likelihood of context prediction
    # pos: word-context pairs, neg: negative samples
    tl = nn.Variable([size, ], need_grad=False)
    loss = F.sigmoid_cross_entropy(hl, tl)
    loss = F.mean(loss)

    # Model for test of searching similar words
    xr = nn.Variable((1,), need_grad=False)
    hr = PF.embed(xr, n_word, n_dim, name="e1")  # feature vector for test

    # Create solver
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = M.Monitor(args.work_dir)
    monitor_loss = M.MonitorSeries(
        "Training loss", monitor, interval=args.monitor_interval)
    monitor_time = M.MonitorTimeElapsed(
        "Training time", monitor, interval=args.monitor_interval)

    # Do training
    max_epoch = args.max_epoch
    for epoch in range(max_epoch):

        # iteration per epoch
        for i in range(di.n_batch):

            # get minibatch
            xi, yi, ti = di.next()

            # learn
            solver.zero_grad()
            xl.d, yl.d, tl.d = xi, yi, ti
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()

            # monitor
            itr = epoch * di.n_batch + i
            monitor_loss.add(itr, loss.d)
            monitor_time.add(itr)

    # Save model
    nn.save_parameters(model_file)

    # Evaluate by similarity
    max_check_words = args.max_check_words
    for i in range(max_check_words):

        # prediction
        xr.d = i
        hr.forward(clear_buffer=True)
        h = hr.d

        # similarity calculation
        w = nn.get_parameters()['e1/embed/W'].d
        s = np.sqrt((w * w).sum(1))
        w /= s.reshape((s.shape[0], 1))
        similarity = w.dot(h[0]) / s[i]

        # for understanding
        output_similar_words(itow, i, similarity)
Example #20
train_data_iter = data_iterator_simple(load_train_func,
                                       len(central_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(central_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

x_central = nn.Variable((batch_size, ))
x_context = nn.Variable((batch_size, ))

with nn.parameter_scope('central_embedding'):
    central_embedding = PF.embed(x_central, vocab_size, embedding_size)
with nn.parameter_scope('context_embedding'):
    context_embedding = PF.embed(x_context, vocab_size, embedding_size)

with nn.parameter_scope('central_bias'):
    central_bias = PF.embed(x_central, vocab_size, 1)
with nn.parameter_scope('context_bias'):
    context_bias = PF.embed(x_context, vocab_size, 1)

dot_product = F.reshape(F.batch_matmul(
    F.reshape(central_embedding, shape=(batch_size, 1, embedding_size)),
    F.reshape(context_embedding, shape=(batch_size, embedding_size, 1))),
                        shape=(batch_size, 1))

prediction = dot_product + central_bias + context_bias

train_data_iter = data_iterator_simple(load_train_func,
                                       len(x_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(x_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

x = nn.Variable((batch_size, sentence_length))
t = nn.Variable((batch_size, sentence_length, 1))
h = PF.embed(x, vocab_size, embedding_size)
h = LSTM(h, hidden, return_sequences=True)
h = TimeDistributed(PF.affine)(h, hidden, name='hidden')
y = TimeDistributed(PF.affine)(h, vocab_size, name='output')

mask = F.sum(F.sign(t), axis=2)  # do not predict 'pad'.
entropy = TimeDistributedSoftmaxCrossEntropy(y, t) * mask
count = F.sum(mask, axis=1)
loss = F.mean(F.div2(F.sum(entropy, axis=1), count))

# Create solver.
solver = S.Momentum(1e-2, momentum=0.9)
solver.set_parameters(nn.get_parameters())

# Create monitor.
from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
train_data_iter = data_iterator_simple(load_train_func,
                                       len(x_train),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)
valid_data_iter = data_iterator_simple(load_valid_func,
                                       len(x_valid),
                                       batch_size,
                                       shuffle=True,
                                       with_file_cache=False)

char_embedding_dim = 16
lstm_size = 650
filters = [50, 150, 200, 200]
filster_sizes = [1, 3, 5, 7]
# filters = [50, 100, 150, 200, 200, 200, 200]
# filster_sizes = [1, 2, 3, 4, 5, 6, 7]

x = nn.Variable((batch_size, sentence_length, word_length))
h = PF.embed(x, char_vocab_size, char_embedding_dim)
h = F.transpose(h, (0, 3, 1, 2))
output = []
for f, f_size in zip(filters, filster_sizes):
    _h = PF.convolution(h,
                        f,
                        kernel=(1, f_size),
                        pad=(0, f_size // 2),
                        name='conv_{}'.format(f_size))
    _h = F.max_pooling(_h, kernel=(1, word_length))
    output.append(_h)
h = F.concatenate(*output, axis=1)
h = F.transpose(h, (0, 2, 1, 3))
h = F.reshape(h, (batch_size, sentence_length, sum(filters)))
# h = PF.batch_normalization(h, axes=[2])
h = TimeDistributed(Highway)(h, name='highway1')