Example No. 1
def loss(target_action, target_action_type, target_action_mask, rule_prob,
         terminal_gen_action_prob, token_prob, copy_prob):
    batch_size, max_action_length, _ = target_action.shape
    _, _, rule_num = rule_prob.shape
    _, _, token_num = token_prob.shape
    _, _, max_query_length = copy_prob.shape

    # (batch_size, max_action_length)
    target_rule, target_token, target_copy = F.split(target_action, axis=2)

    target_rule = F.reshape(target_rule, (batch_size, max_action_length, 1))
    target_rule = F.one_hot(
        target_rule, (rule_num, ))  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = rule_prob * target_rule  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = F.sum(rule_tgt_prob,
                          axis=2)  # (batch_size, max_action_length)

    target_token = F.reshape(target_token, (batch_size, max_action_length, 1))
    target_token = F.one_hot(
        target_token,
        (token_num, ))  # (batch_size, max_action_length, token_num)
    token_tgt_prob = token_prob * target_token  # (batch_size, max_action_length, token_num)
    token_tgt_prob = F.sum(token_tgt_prob,
                           axis=2)  # (batch_size, max_action_length)

    target_copy = F.reshape(target_copy, (batch_size, max_action_length, 1))
    target_copy = F.one_hot(
        target_copy, (max_query_length,
                      ))  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = copy_prob * target_copy  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = F.sum(copy_tgt_prob,
                          axis=2)  # (batch_size, max_action_length)

    # (batch_size, max_action_length)
    gen_token_prob, copy_token_prob = F.split(terminal_gen_action_prob, axis=2)
    # (batch_size, max_action_length)
    rule_mask, token_mask, copy_mask = F.split(target_action_type, axis=2)

    # (batch_size, max_action_length)
    target_prob = rule_mask * rule_tgt_prob + \
                  token_mask * gen_token_prob * token_tgt_prob + \
                  copy_mask * copy_token_prob * copy_tgt_prob
    # (batch_size, max_action_length)
    likelihood = F.log(target_prob + 1e-7)
    loss = -likelihood * target_action_mask
    # (batch_size)
    loss = F.sum(loss, axis=1)
    return F.mean(loss)
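All the examples in this listing rely on nnabla's F.one_hot turning an integer index variable whose last axis holds the indices into a one-hot variable over the given shape, e.g. (B, 1) indices into (B, C) one-hot vectors. A minimal sketch of that behavior, assuming only that nnabla and NumPy are installed:

import numpy as np
import nnabla as nn
import nnabla.functions as F

indices = nn.Variable.from_numpy_array(np.array([[2], [0]]))  # (2, 1) integer indices
with nn.auto_forward():
    onehot = F.one_hot(indices, (4, ))  # -> (2, 4)
print(onehot.d)  # [[0. 0. 1. 0.], [1. 0. 0. 0.]]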
Example No. 2
def softmax_cross_entropy_backward(inputs, axis=None):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    t0 = inputs[2]

    D = len(x0.shape)
    axis = positive_axis(axis, D)
    c0 = x0.shape[axis]
    t0_shape = [s for s in t0.shape if s != 1]
    u0 = F.reshape(t0, (-1, 1), inplace=False)
    u1 = F.one_hot(u0, (c0, ))
    to = F.reshape(u1, t0_shape + [
        c0,
    ])
    t0 = no_grad(to)
    if axis != len(to.shape) - 1:
        oaxes = [i for i in range(len(t0_shape))]
        taxes = oaxes[:axis] + [to.ndim - 1] + oaxes[axis:]
        to = F.transpose(to, taxes)
    dx0 = dy * (F.softmax(x0, axis=axis) - to)
    return dx0, None
Example No. 3
    def mix_data(self, image, label):
        '''
        Define mixed data Variables.

        Args:
            image(Variable): (B, C, H, W) or (B, H, W, C)
            label(Variable): (B, 1) of integers in [0, num_classes)

        Returns:
            image(Variable): mixed data
            label(Variable): mixed label with (B, num_classes)

        '''
        if image.shape[0] % 2 != 0:
            raise ValueError(
                'Please use an even batch size with this implementation of mixup regularization. Given {}.'
                .format(image.shape[0]))
        image2 = image[::-1]
        label = F.one_hot(label, (self.num_classes, ))
        label2 = label[::-1]
        self.lam = nn.Variable((image.shape[0], 1, 1, 1))
        if get_nnabla_version_integer() < 10700:
            raise ValueError(
                'This does not work with nnabla version less than 1.7.0 due to [a bug](https://github.com/sony/nnabla/pull/608). Please update the nnabla version.'
            )
        llam = F.reshape(self.lam, (-1, 1))
        self.reset_mixup_ratio()  # Call it to be safe.
        mimage = self.lam * image + (1 - self.lam) * image2
        mlabel = llam * label + (1 - llam) * label2
        return mimage, mlabel
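The mixing above follows the standard mixup formula. A minimal NumPy sketch of the per-sample computation, assuming the ratio lam is drawn from Beta(alpha, alpha) as in the original mixup paper (alpha and the shapes below are hypothetical):

import numpy as np

alpha = 0.2                                        # hypothetical hyperparameter
rng = np.random.RandomState(0)
x1, x2 = rng.rand(3, 32, 32), rng.rand(3, 32, 32)  # two images
y1, y2 = np.eye(10)[3], np.eye(10)[7]              # their one-hot labels
lam = rng.beta(alpha, alpha)
x_mix = lam * x1 + (1 - lam) * x2                  # mixed image
y_mix = lam * y1 + (1 - lam) * y2                  # mixed label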
Example No. 4
    def random_generate(self, num_images, path):

        # Generate from the uniform prior of the base model
        indices = F.randint(low=0,
                            high=self.num_embedding,
                            shape=[num_images] + self.latent_shape)
        indices = F.reshape(indices, (-1, ), inplace=True)
        quantized = F.embed(indices, self.base_model.vq.embedding_weight)
        quantized = F.transpose(
            quantized.reshape([num_images] + self.latent_shape +
                              [quantized.shape[-1]]), (0, 3, 1, 2))

        img_gen_uniform_prior = self.base_model(quantized,
                                                quantized_as_input=True,
                                                test=True)

        # Generate images using pixelcnn prior
        indices = nn.Variable.from_numpy_array(
            np.zeros(shape=[num_images] + self.latent_shape))
        labels = F.randint(low=0, high=self.num_classes, shape=(num_images, 1))
        labels = F.one_hot(labels, shape=(self.num_classes, ))

        # Sample from pixelcnn - pixel by pixel
        import torch  # NumPy's multinomial sampling behaves differently here and does not give the correct output
        for i in range(self.latent_shape[0]):
            for j in range(self.latent_shape[1]):
                quantized = F.embed(indices.reshape((-1, )),
                                    self.base_model.vq.embedding_weight)
                quantized = F.transpose(
                    quantized.reshape([num_images] + self.latent_shape +
                                      [quantized.shape[-1]]), (0, 3, 1, 2))
                indices_sample = self.prior(quantized, labels)
                indices_prob = F.reshape(indices_sample,
                                         indices.shape +
                                         (indices_sample.shape[-1], ),
                                         inplace=True)[:, i, j]
                indices_prob = F.softmax(indices_prob)

                indices_prob_tensor = torch.from_numpy(indices_prob.d)
                sample = indices_prob_tensor.multinomial(1).squeeze().numpy()
                indices[:, i, j] = sample

        print(indices.d)
        quantized = F.embed(indices.reshape((-1, )),
                            self.base_model.vq.embedding_weight)
        quantized = F.transpose(
            quantized.reshape([num_images] + self.latent_shape +
                              [quantized.shape[-1]]), (0, 3, 1, 2))

        img_gen_pixelcnn_prior = self.base_model(quantized,
                                                 quantized_as_input=True,
                                                 test=True)

        self.save_image(img_gen_uniform_prior,
                        os.path.join(path, 'generate_uniform.png'))
        self.save_image(img_gen_pixelcnn_prior,
                        os.path.join(path, 'generate_pixelcnn.png'))

        print('Random labels generated for pixelcnn prior:',
              list(F.max(labels, axis=1, only_index=True).d))
Example No. 5
def encode_inputs(inst_label,
                  id_label,
                  n_ids,
                  use_encoder=False,
                  channel_last=False):
    """
    :param inst_label: (N, H, W) or (N, H, W, 1)
    :param id_label: (N, H, W) or (N, H, W, 1)
    :param n_ids: int, number of ids (channels of the one-hot output)
    :param use_encoder: boolean
    :param channel_last: boolean; if True, outputs are returned as (N, H, W, C) instead of (N, C, H, W)
    :return:
    """
    # id (index) -> onehot
    _check_intput(id_label)
    if len(id_label.shape) == 3:
        id_label = id_label.reshape(id_label.shape + (1, ))
    id_onehot = F.one_hot(id_label, shape=(n_ids, ))

    # inst -> boundary map
    _check_intput(inst_label)
    bm = inst_to_boundary(inst_label)
    if len(bm.shape) == 3:
        bm = bm.reshape(bm.shape + (1, ))

    if use_encoder:
        # todo: implement encoder network
        pass

    if channel_last:
        return id_onehot, bm

    return F.transpose(id_onehot, (0, 3, 1, 2)), F.transpose(bm, (0, 3, 1, 2))
Example No. 6
def create_network(batch_size, num_dilations, learning_rate):
    # model
    x = nn.Variable(shape=(batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    s_emb = None

    net = WaveNet(num_dilations)
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))

    # (B, T, 1)
    t = nn.Variable(shape=(batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))
    # loss.visit(PrintFunc())

    # Create Solver.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    return x, t, loss, solver
Example No. 7
    def build_train_graph(self, batch):
        self.solver = S.Adam(self.learning_rate)

        obs, action, reward, terminal, newobs = batch
        # Create input variables
        s = nn.Variable(obs.shape)
        a = nn.Variable(action.shape)
        r = nn.Variable(reward.shape)
        t = nn.Variable(terminal.shape)
        snext = nn.Variable(newobs.shape)
        with nn.parameter_scope(self.name_q):
            q = self.q_builder(s, self.num_actions, test=False)
            self.solver.set_parameters(nn.get_parameters())
        with nn.parameter_scope(self.name_qnext):
            qnext = self.q_builder(snext, self.num_actions, test=True)
        qnext.need_grad = False
        clipped_r = F.minimum_scalar(F.maximum_scalar(
            r, -self.clip_reward), self.clip_reward)
        q_a = F.sum(
            q * F.one_hot(F.reshape(a, (-1, 1), inplace=False), (q.shape[1],)), axis=1)
        target = clipped_r + self.gamma * (1 - t) * F.max(qnext, axis=1)
        loss = F.mean(F.huber_loss(q_a, target))
        Variables = namedtuple(
            'Variables', ['s', 'a', 'r', 't', 'snext', 'q', 'loss'])
        self.v = Variables(s, a, r, t, snext, q, loss)
        self.sync_models()
        self.built = True
Example No. 8
    def _build(self):
        # infer variable
        self.infer_obs_t = nn.Variable((1, 4, 84, 84))
        # inference output
        self.infer_qs_t = self.q_function(self.infer_obs_t, self.num_actions,
                                          self.num_heads, 'q_func')
        self.infer_all = F.sink(*self.infer_qs_t)

        # train variables
        self.obss_t = nn.Variable((self.batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((self.batch_size, 1))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84))
        self.ters_tp1 = nn.Variable((self.batch_size, 1))
        self.weights = nn.Variable((self.batch_size, self.num_heads))

        # training output
        qs_t = self.q_function(self.obss_t, self.num_actions, self.num_heads,
                               'q_func')
        qs_tp1 = self.q_function(self.obss_tp1, self.num_actions,
                                 self.num_heads, 'target')
        stacked_qs_t = F.transpose(F.stack(*qs_t), [1, 0, 2])
        stacked_qs_tp1 = F.transpose(F.stack(*qs_tp1), [1, 0, 2])

        # select one dimension
        a_one_hot = F.reshape(F.one_hot(self.acts_t, (self.num_actions, )),
                              (-1, 1, self.num_actions))
        # mask output
        q_t_selected = F.sum(stacked_qs_t * a_one_hot, axis=2)
        q_tp1_best = F.max(stacked_qs_tp1, axis=2)
        q_tp1_best.need_grad = False

        # reward clipping
        clipped_rews_tp1 = clip_by_value(self.rews_tp1, -1.0, 1.0)

        # loss calculation
        y = clipped_rews_tp1 + self.gamma * q_tp1_best * (1.0 - self.ters_tp1)
        td = F.huber_loss(q_t_selected, y)
        self.loss = F.mean(F.sum(td * self.weights, axis=1))

        # optimizer
        self.solver = S.RMSprop(self.lr, 0.95, 1e-2)

        # weights and biases
        with nn.parameter_scope('q_func'):
            self.params = nn.get_parameters()
            self.head_params = []
            for i in range(self.num_heads):
                with nn.parameter_scope('head%d' % i):
                    self.head_params.append(nn.get_parameters())
            with nn.parameter_scope('shared'):
                self.shared_params = nn.get_parameters()
        with nn.parameter_scope('target'):
            self.target_params = nn.get_parameters()

        # set q function parameters to solver
        self.solver.set_parameters(self.params)
Example No. 9
def test_one_hot_forward(seed, inshape, shape, ctx, func_name):
    rng = np.random.RandomState(seed)
    # Input
    input = rng.randint(0, shape[0], size=inshape)
    vinput = nn.Variable(input.shape, need_grad=False)
    vinput.d = input
    with nn.context_scope(ctx), nn.auto_forward():
        o = F.one_hot(vinput, shape)
    r = ref_one_hot(input, shape)
    assert np.allclose(o.d, r)
    assert func_name == o.parent.name
Example No. 10
def test_one_hot_forward(seed, inshape, shape, ctx, func_name):
    # Input
    input = np.zeros(inshape, dtype=int)
    rng = np.random.RandomState(seed)

    if len(shape) != inshape[-1]:
        # input inshape and shape don't match.
        with pytest.raises(RuntimeError):
            y = F.one_hot(nn.Variable(input.shape), shape)
    else:
        for i in range(inshape[-1]):
            input[:, i] = rng.randint(0, shape[i], size=inshape[0])
        vinput = nn.Variable(input.shape, need_grad=False)
        vinput.d = input

        with nn.context_scope(ctx), nn.auto_forward():
            o = F.one_hot(vinput, shape)
        r = ref_one_hot(input, shape)
        assert np.allclose(o.d, r)
        assert func_name == o.parent.name
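The tests above and below compare against a reference helper ref_one_hot that is not shown in this listing. A plausible NumPy version with the same calling convention (hypothetical; the real helper lives in the nnabla test utilities) could look like this:

import numpy as np

def ref_one_hot(indices, shape):
    # indices: (..., len(shape)) integer array; output: (..., *shape) with a single 1 per sample
    out = np.zeros(indices.shape[:-1] + tuple(shape), dtype=np.float32)
    for idx in np.ndindex(indices.shape[:-1]):
        out[idx][tuple(indices[idx])] = 1.0
    return out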
Example No. 11
def test_one_hot_forward(seed, inshape, shape, ctx, func_name):
    rng = np.random.RandomState(seed)
    # Input
    input = rng.randint(0, shape[0], size=inshape)
    vinput = nn.Variable(input.shape, need_grad=False)
    vinput.d = input
    with nn.context_scope(ctx), nn.auto_forward():
        o = F.one_hot(vinput, shape)
    r = ref_one_hot(input, shape)
    assert np.allclose(o.d, r)
    assert func_name == o.parent.name
Example No. 12
    def forward_pass(self, img_var, labels):
        enc_indices, quantized = self.base_model(img_var,
                                                 return_encoding_indices=True,
                                                 test=True)
        labels_var = nn.Variable(labels.shape)
        if isinstance(labels, nn.NdArray):
            labels_var.data = labels
        else:
            labels_var.d = labels
        labels_var = F.one_hot(labels_var, shape=(self.num_classes, ))
        enc_recon = self.prior(quantized, labels_var)
        loss = F.mean(F.softmax_cross_entropy(enc_recon, enc_indices))

        return loss, enc_indices, enc_recon
Example No. 13
def model_tweak_digitscaps(batch_size):
    '''
    Build the graph used to tweak the DigitCaps outputs: the capsule network is
    run in test mode and the reconstruction is computed from the noise-perturbed
    capsule outputs.
    '''
    image = nn.Variable((batch_size, 1, 28, 28))
    label = nn.Variable((batch_size, 1))
    x = image / 255.0
    t_onehot = F.one_hot(label, (10,))
    with nn.parameter_scope("capsnet"):
        _, _, _, caps, _ = model.capsule_net(
            x, test=True, aug=False, grad_dynamic_routing=True)
    noise = nn.Variable((batch_size, 1, caps.shape[2]))
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot, noise)
    return image, label, noise, recon
Example No. 14
def mlp_gradient_synthesizer(x, y=None, test=False):
    maps = x.shape[1]
    if y is not None:
        h = F.one_hot(y, (10, ))
        h = F.concatenate(*[x, h], axis=1)
    else:
        h = x
    with nn.parameter_scope("gs"):
        h = act_bn_linear(h, maps, test, name="fc0")
        h = act_bn_linear(h, maps, test, name="fc1")
        w_init = ConstantInitializer(0)
        b_init = ConstantInitializer(0)
        g_pred = PF.affine(h, maps, w_init=w_init, b_init=b_init, name="fc")
        g_pred.persistent = True
    return g_pred
Example No. 15
    def _build(self):
        # infer variable
        self.infer_obs_t = nn.Variable((1, 4, 84, 84))
        # inference output
        self.infer_q_t = self.q_function(self.infer_obs_t,
                                         self.num_actions,
                                         scope='q_func')

        # train variables
        self.obss_t = nn.Variable((self.batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((self.batch_size, 1))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84))
        self.ters_tp1 = nn.Variable((self.batch_size, 1))
        self.weights = nn.Variable((self.batch_size, 1))

        # training output
        q_t = self.q_function(self.obss_t, self.num_actions, scope='q_func')
        q_tp1 = self.q_function(self.obss_tp1,
                                self.num_actions,
                                scope='target_q_func')

        # select one dimension
        a_t_one_hot = F.one_hot(self.acts_t, (self.num_actions, ))
        q_t_selected = F.sum(q_t * a_t_one_hot, axis=1, keepdims=True)
        q_tp1_best = F.max(q_tp1, axis=1, keepdims=True)

        # loss calculation
        y = self.rews_tp1 + self.gamma * q_tp1_best * (1.0 - self.ters_tp1)
        self.td = q_t_selected - y
        self.loss = F.sum(F.huber_loss(q_t_selected, y) * self.weights)
        self.loss_sink = F.sink(self.td, self.loss)

        # optimizer
        self.solver = S.RMSprop(self.lr, 0.95, 1e-2)

        # weights and biases
        with nn.parameter_scope('q_func'):
            self.params = nn.get_parameters()
        with nn.parameter_scope('target_q_func'):
            self.target_params = nn.get_parameters()

        # set q function parameters to solver
        self.solver.set_parameters(self.params)
Example No. 16
    def __call__(self, x, return_encoding_indices=False):

        x = F.transpose(x, (0, 2, 3, 1))
        x_flat = x.reshape((-1, self.embedding_dim))

        x_flat_squared = F.broadcast(F.sum(x_flat**2, axis=1, keepdims=True),
                                     (x_flat.shape[0], self.num_embedding))
        emb_wt_squared = F.transpose(
            F.sum(self.embedding_weight**2, axis=1, keepdims=True), (1, 0))

        distances = x_flat_squared + emb_wt_squared - 2 * \
            F.affine(x_flat, F.transpose(self.embedding_weight, (1, 0)))

        encoding_indices = F.min(distances,
                                 only_index=True,
                                 axis=1,
                                 keepdims=True)
        encoding_indices.need_grad = False

        quantized = F.embed(
            encoding_indices.reshape(encoding_indices.shape[:-1]),
            self.embedding_weight).reshape(x.shape)

        if return_encoding_indices:
            return encoding_indices, F.transpose(quantized, (0, 3, 1, 2))

        encodings = F.one_hot(encoding_indices, (self.num_embedding, ))

        e_latent_loss = F.mean(
            F.squared_error(quantized.get_unlinked_variable(need_grad=False),
                            x))
        q_latent_loss = F.mean(
            F.squared_error(quantized,
                            x.get_unlinked_variable(need_grad=False)))
        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        quantized = x + (quantized - x).get_unlinked_variable(need_grad=False)

        avg_probs = F.mean(encodings, axis=0)
        perplexity = F.exp(-F.sum(avg_probs * F.log(avg_probs + 1.0e-10)))

        return loss, F.transpose(quantized,
                                 (0, 3, 1, 2)), perplexity, encodings
Example No. 17
    def __init__(self, num_actions, num_envs, batch_size, v_coeff, ent_coeff,
                 lr_scheduler):
        # inference graph
        self.infer_obs_t = nn.Variable((num_envs, 4, 84, 84))
        self.infer_pi_t,\
        self.infer_value_t = cnn_network(self.infer_obs_t, num_actions,
                                         'network')
        self.infer_t = F.sink(self.infer_pi_t, self.infer_value_t)

        # evaluation graph
        self.eval_obs_t = nn.Variable((1, 4, 84, 84))
        self.eval_pi_t, _ = cnn_network(self.eval_obs_t, num_actions,
                                        'network')

        # training graph
        self.obss_t = nn.Variable((batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((batch_size, 1))
        self.rets_t = nn.Variable((batch_size, 1))
        self.advs_t = nn.Variable((batch_size, 1))

        pi_t, value_t = cnn_network(self.obss_t, num_actions, 'network')

        # value loss
        l2loss = F.squared_error(value_t, self.rets_t)
        self.value_loss = v_coeff * F.mean(l2loss)

        # policy loss
        log_pi_t = F.log(pi_t + 1e-20)
        a_one_hot = F.one_hot(self.acts_t, (num_actions, ))
        log_probs_t = F.sum(log_pi_t * a_one_hot, axis=1, keepdims=True)
        self.pi_loss = F.mean(log_probs_t * self.advs_t)

        # entropy regularization (weighted by ent_coeff)
        entropy = -ent_coeff * F.mean(F.sum(pi_t * log_pi_t, axis=1))

        self.loss = self.value_loss - self.pi_loss - entropy

        self.params = nn.get_parameters()
        self.solver = S.RMSprop(lr_scheduler(0.0), 0.99, 1e-5)
        self.solver.set_parameters(self.params)
        self.lr_scheduler = lr_scheduler
Example No. 18
def cnn_gradient_synthesizer(x, y=None, test=False):
    bs = x.shape[0]
    maps = x.shape[1]
    s0, s1 = x.shape[2:]
    if y is not None:
        h = F.one_hot(y, (10, ))
        h = F.reshape(h, (bs, 10, 1, 1))
        h = F.broadcast(h, (bs, 10, s0, s1))
        h = F.concatenate(*[x, h], axis=1)
    else:
        h = x
    with nn.parameter_scope("gs"):
        h = act_bn_conv(h, maps, test, name="conv0")
        w_init = ConstantInitializer(0)
        b_init = ConstantInitializer(0)
        g_pred = PF.convolution(h,
                                maps,
                                kernel=(3, 3),
                                pad=(1, 1),
                                w_init=w_init,
                                b_init=b_init,
                                name="conv")
        g_pred.persistent = True
    return g_pred
Example No. 19
    def define_network(self):

        if self.use_inst:
            obj_onehot, bm = encode_inputs(self.ist_mask,
                                           self.obj_mask,
                                           n_ids=self.conf.n_class)

            mask = F.concatenate(obj_onehot, bm, axis=1)
        else:
            om = self.obj_mask
            if len(om.shape) == 3:
                om = F.reshape(om, om.shape + (1, ))
            obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
            mask = F.transpose(obj_onehot, (0, 3, 1, 2))

        generator = SpadeGenerator(self.conf.g_ndf,
                                   image_shape=self.conf.image_shape)
        z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
        fake = generator(z, mask)

        # Pixel intensities of fake are [-1, 1]. Rescale it to [0, 1]
        fake = (fake + 1) / 2

        return fake
Example No. 20
def network(x, d1, c1, d2, c2, test=False):
    # Input:x -> 1
    # OneHot -> 687
    h = F.one_hot(x, (687, ))

    # LSTM1 -> 200
    with nn.parameter_scope('LSTM1'):
        h = network_LSTM(h, d1, c1, 687, 100, test)

    # Slice -> 100
    h1 = F.slice(h, (0, ), (100, ), (1, ))

    # h2:CellOut -> 100
    h2 = F.slice(h, (100, ), (200, ), (1, ))

    # LSTM2 -> 128
    with nn.parameter_scope('LSTM2'):
        h3 = network_LSTM(h1, d2, c2, 100, 64, test)

    # h4:DelayOut
    h4 = F.identity(h1)

    # Slice_2 -> 64
    h5 = F.slice(h3, (0, ), (64, ), (1, ))

    # h6:CellOut_2 -> 64
    h6 = F.slice(h3, (64, ), (128, ), (1, ))

    # Affine_2 -> 687
    h7 = PF.affine(h5, (687, ), name='Affine_2')

    # h8:DelayOut_2
    h8 = F.identity(h5)
    # h7:Softmax
    h7 = F.softmax(h7)
    return h2, h4, h6, h8, h7
Example No. 21
def train():
    args = get_args()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in {}:{}".format(args.context, args.type_config))
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir)
    _data_source = data_iterator._data_source  # dirty hack...

    # model
    x = nn.Variable(
        shape=(args.batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    if args.use_speaker_id:
        s_id = nn.Variable(shape=(args.batch_size, 1))
        with nn.parameter_scope("speaker_embedding"):
            s_emb = PF.embed(s_id, n_inputs=_data_source.n_speaker,
                             n_features=WavenetConfig.speaker_dims)
            s_emb = F.transpose(s_emb, (0, 2, 1))
    else:
        s_emb = None

    net = WaveNet()
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))

    # (B, T, 1)
    t = nn.Variable(shape=(args.batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # for generation
    prob = F.softmax(pred)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)

    # setup save env.
    audio_save_path = os.path.join(os.path.abspath(
        args.model_save_path), "audio_results")
    if audio_save_path and not os.path.exists(audio_save_path):
        os.makedirs(audio_save_path)

    # Training loop.
    for i in range(args.max_iter):
        # todo: validation

        x.d, _speaker, t.d = data_iterator.next()
        if args.use_speaker_id:
            s_id.d = _speaker.reshape(-1, 1)

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.update()

        loss.data.cast(np.float32, ctx)
        monitor_loss.add(i, loss.d.copy())

        if i % args.model_save_interval == 0:
            prob.forward()
            audios = mu_law_decode(
                np.argmax(prob.d, axis=-1), quantize=data_config.q_bit_len)  # (B, T)
            save_audio(audios, i, audio_save_path)
Example No. 22
def cond_att_lstm(x,
                  parent_index,
                  mask,
                  context,
                  context_mask,
                  state_size,
                  att_hidden_size,
                  initial_state=None,
                  initial_cell=None,
                  hist=None,
                  dropout=0,
                  train=True,
                  w_init=None,
                  inner_w_init=None,
                  b_init=I.ConstantInitializer(0),
                  forget_bias_init=I.ConstantInitializer(1)):
    """
    x: (batch_size, length, input_size)
    parent_index: (batch_size, length)
    mask: (batch_size, length)
    context: (batch_size, context_length, context_size)
    context_mask: (batch_size, context_length)
    hist: (batch_size, l, state_size)
    """
    batch_size, length, input_size = x.shape
    _, context_length, context_size = context.shape

    if w_init is None:
        w_init = I.UniformInitializer(
            I.calc_uniform_lim_glorot(input_size, state_size))
    if inner_w_init is None:
        inner_w_init = orthogonal

    retain_prob = 1.0 - dropout
    z_w = nn.Variable((batch_size, 4, input_size), need_grad=False)
    z_w.d = 1
    z_u = nn.Variable((batch_size, 4, state_size), need_grad=False)
    z_u.d = 1

    if dropout > 0:
        if train:
            z_w = F.dropout(z_w, p=retain_prob)
            z_u = F.dropout(z_u, p=retain_prob)
        z_w *= retain_prob
        z_u *= retain_prob

    z_w = F.reshape(z_w, (batch_size, 4, 1, input_size))
    z_w = F.broadcast(z_w, (batch_size, 4, length, input_size))
    z_w = F.split(z_w, axis=1)
    z_u = F.split(z_u, axis=1)
    xi = z_w[0] * x
    xf = z_w[1] * x
    xc = z_w[2] * x
    xo = z_w[3] * x

    with nn.parameter_scope("cond_att_lstm"):
        # (batch_size, length, state_size)
        with nn.parameter_scope("lstm"):
            xi = PF.affine(
                xi,
                state_size,
                base_axis=2,
                w_init=w_init,
                b_init=b_init,
                name="Wi")
            xf = PF.affine(
                xf,
                state_size,
                base_axis=2,
                w_init=w_init,
                b_init=forget_bias_init,
                name="Wf")
            xc = PF.affine(
                xc,
                state_size,
                base_axis=2,
                w_init=w_init,
                b_init=b_init,
                name="Wc")
            xo = PF.affine(
                xo,
                state_size,
                base_axis=2,
                w_init=w_init,
                b_init=b_init,
                name="Wo")

        with nn.parameter_scope("context"):
            # context_att_trans: (batch_size, context_length, att_hidden_size)
            context_att_trans = PF.affine(
                context,
                att_hidden_size,
                base_axis=2,
                w_init=w_init,
                b_init=b_init,
                name="layer1_c")

    if initial_state is None:
        h = nn.Variable((batch_size, state_size), need_grad=False)
        h.data.zero()
    else:
        h = initial_state

    if initial_cell is None:
        c = nn.Variable((batch_size, state_size), need_grad=False)
        c.data.zero()
    else:
        c = initial_cell

    if hist is None:
        hist = nn.Variable((batch_size, 1, state_size), need_grad=False)
        hist.data.zero()

    # (batch_size, state_size)
    xi = split(xi, axis=1)
    xf = split(xf, axis=1)
    xc = split(xc, axis=1)
    xo = split(xo, axis=1)
    mask = F.reshape(mask, [batch_size, length, 1])  # (batch_size, length, 1)
    mask = F.broadcast(mask, [batch_size, length, state_size])
    # (batch_size, state_size)
    mask = split(mask, axis=1)
    # (batch_size, length)
    parent_index = parent_index + 1  # index == 0 means that parent is root
    # (batch_size)
    parent_index = split(parent_index, axis=1)

    hs = []
    cs = []
    ctx = []

    for i, f, c2, o, m, p in zip(xi, xf, xc, xo, mask, parent_index):
        h_num = hist.shape[1]
        with nn.parameter_scope("context"):
            h_att_trans = PF.affine(
                h,
                att_hidden_size,
                with_bias=False,
                w_init=w_init,
                name="layer1_h")  # (batch_size, att_hidden_size)
            h_att_trans = F.reshape(h_att_trans,
                                    (batch_size, 1, att_hidden_size))
            h_att_trans = F.broadcast(
                h_att_trans, (batch_size, context_length, att_hidden_size))
            att_hidden = F.tanh(context_att_trans + h_att_trans)
            att_raw = PF.affine(
                att_hidden, 1, base_axis=2, w_init=w_init,
                b_init=b_init)  # (batch_size, context_length, 1)
            att_raw = F.reshape(att_raw, (batch_size, context_length))
            ctx_att = F.exp(att_raw - F.max(att_raw, axis=1, keepdims=True))
            ctx_att = ctx_att * context_mask
            ctx_att = ctx_att / F.sum(ctx_att, axis=1, keepdims=True)
            ctx_att = F.reshape(ctx_att, (batch_size, context_length, 1))
            ctx_att = F.broadcast(ctx_att,
                                  (batch_size, context_length, context_size))
            ctx_vec = F.sum(
                context * ctx_att, axis=1)  # (batch_size, context_size)

        # parent_history
        p = F.reshape(p, (batch_size, 1))
        p = F.one_hot(p, (h_num, ))
        p = F.reshape(p, (batch_size, 1, h_num))
        par_h = F.batch_matmul(p, hist)  # [batch_size, 1, state_size]
        par_h = F.reshape(par_h, (batch_size, state_size))

        with nn.parameter_scope("lstm"):
            i_t = PF.affine(
                z_u[0] * h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Ui")
            i_t += PF.affine(
                ctx_vec,
                state_size,
                w_init=inner_w_init(context_size, state_size),
                with_bias=False,
                name="Ci")
            i_t += PF.affine(
                par_h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Pi")
            i_t = F.sigmoid(i + i_t)
            f_t = PF.affine(
                z_u[1] * h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Uf")
            f_t += PF.affine(
                ctx_vec,
                state_size,
                w_init=inner_w_init(context_size, state_size),
                with_bias=False,
                name="Cf")
            f_t += PF.affine(
                par_h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Pf")
            f_t = F.sigmoid(f + f_t)
            c_t = PF.affine(
                z_u[2] * h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Uc")
            c_t += PF.affine(
                ctx_vec,
                state_size,
                w_init=inner_w_init(context_size, state_size),
                with_bias=False,
                name="Cc")
            c_t += PF.affine(
                par_h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Pc")
            c_t = f_t * c + i_t * F.tanh(c2 + c_t)
            o_t = PF.affine(
                z_u[3] * h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Uo")
            o_t += PF.affine(
                ctx_vec,
                state_size,
                w_init=inner_w_init(context_size, state_size),
                with_bias=False,
                name="Co")
            o_t += PF.affine(
                par_h,
                state_size,
                w_init=inner_w_init(state_size, state_size),
                with_bias=False,
                name="Po")
            o_t = F.sigmoid(o + o_t)
            h_t = o_t * F.tanh(c_t)

            h_t = (1 - m) * h + m * h_t
            c_t = (1 - m) * c + m * c_t
            h = h_t
            c = c_t
            h_t = F.reshape(h_t, (batch_size, 1, state_size), inplace=False)
            c_t = F.reshape(c_t, (batch_size, 1, state_size), inplace=False)
            ctx_vec = F.reshape(
                ctx_vec, (batch_size, 1, context_size), inplace=False)
            hs.append(h_t)
            cs.append(c_t)
            ctx.append(ctx_vec)

            hist = F.concatenate(
                hist, h_t, axis=1)  # (batch_size, h_num + 1, state_size)

    return concatenate(
        *hs, axis=1), concatenate(
            *cs, axis=1), concatenate(
                *ctx, axis=1), hist
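The parent-history lookup in the loop above (F.one_hot over the parent index followed by F.batch_matmul against hist) is effectively a per-sample gather of one hidden state. A small NumPy sketch of the same operation, with hypothetical shapes:

import numpy as np

batch_size, h_num, state_size = 2, 3, 4
hist = np.arange(batch_size * h_num * state_size,
                 dtype=np.float32).reshape(batch_size, h_num, state_size)
parent = np.array([2, 0])                                 # parent index per sample
p = np.eye(h_num, dtype=np.float32)[parent][:, None, :]   # (batch_size, 1, h_num) one-hot
par_h = (p @ hist).reshape(batch_size, state_size)        # == hist[np.arange(batch_size), parent]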
Example No. 23
def train():
    '''
    Main script.
    '''
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # TRAIN
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    x = image / 255.0
    t_onehot = F.one_hot(label, (10, ))
    with nn.parameter_scope("capsnet"):
        c1, pcaps, u_hat, caps, pred = model.capsule_net(
            x,
            test=False,
            aug=True,
            grad_dynamic_routing=args.grad_dynamic_routing)
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot)
    loss_margin, loss_reconst, loss = model.capsule_loss(
        pred, t_onehot, recon, x)
    pred.persistent = True

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vx = vimage / 255.0
    with nn.parameter_scope("capsnet"):
        _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    train_iter = int(60000 / args.batch_size)
    val_iter = int(10000 / args.batch_size)
    logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter))
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1)
    monitor_rloss = MonitorSeries("Training reconstruction loss",
                                  monitor,
                                  interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_lr = MonitorSeries("Learning rate", monitor, interval=1)

    # To_save_nnp
    m_image, m_label, m_noise, m_recon = model_tweak_digitscaps(
        args.batch_size)
    contents = save_nnp({
        'x1': m_image,
        'x2': m_label,
        'x3': m_noise
    }, {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_epoch0_result.nnp'),
              contents)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)
    start_point = 0

    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)
    # Training loop.
    for e in range(start_point, args.max_epochs):

        # Learning rate decay
        learning_rate = solver.learning_rate()
        if e != 0:
            learning_rate *= 0.9
        solver.set_learning_rate(learning_rate)
        monitor_lr.add(e, learning_rate)

        # Training
        train_error = 0.0
        train_loss = 0.0
        train_mloss = 0.0
        train_rloss = 0.0
        for i in range(train_iter):
            image.d, label.d = data.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            train_error += categorical_error(pred.d, label.d)
            train_loss += loss.d
            train_mloss += loss_margin.d
            train_rloss += loss_reconst.d
        train_error /= train_iter
        train_loss /= train_iter
        train_mloss /= train_iter
        train_rloss /= train_iter

        # Validation
        val_error = 0.0
        for j in range(val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            val_error += categorical_error(vpred.d, vlabel.d)
        val_error /= val_iter

        # Monitor
        monitor_time.add(e)
        monitor_loss.add(e, train_loss)
        monitor_mloss.add(e, train_mloss)
        monitor_rloss.add(e, train_rloss)
        monitor_err.add(e, train_error)
        monitor_verr.add(e, val_error)
        save_checkpoint(args.monitor_path, e, solver)

    # To_save_nnp
    contents = save_nnp({
        'x1': m_image,
        'x2': m_label,
        'x3': m_noise
    }, {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_result.nnp'), contents)
Example No. 24
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_valid_samples = 10000
    bs_valid = args.batch_size
    extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Dataset
    data_iterator = data_iterator_cifar10
    n_class = 10

    # Model architecture
    if args.net == "resnet18":
        prediction = functools.partial(resnet18_prediction,
                                       ncls=n_class,
                                       nmaps=64,
                                       act=F.relu)
    if args.net == "resnet34":
        prediction = functools.partial(resnet34_prediction,
                                       ncls=n_class,
                                       nmaps=64,
                                       act=F.relu)

    # Create training graphs
    test = False
    if args.mixtype == "mixup":
        mdl = MixupLearning(args.batch_size, alpha=args.alpha)
    elif args.mixtype == "cutmix":
        mdl = CutmixLearning((args.batch_size, 3, 32, 32),
                             alpha=args.alpha,
                             cutmix_prob=1.0)
    elif args.mixtype == "vhmixup":
        mdl = VHMixupLearning((args.batch_size, 3, 32, 32), alpha=args.alpha)
    else:
        print("[ERROR] Unknown mixtype: " + args.mixtype)
        return
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    mix_image, mix_label = mdl.mix_data(single_image_augment(image_train),
                                        F.one_hot(label_train, (n_class, )))
    pred_train = prediction(mix_image, test)
    loss_train = mdl.loss(pred_train, mix_label)
    input_train = {"image": image_train, "label": label_train}

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_valid = {"image": image_valid}

    # Solvers
    if args.solver == "Adam":
        solver = S.Adam()
    elif args.solver == "Momentum":
        solver = S.Momentum(lr=args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.save_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data Iterator
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    print("Size of the training data: %d " % tdata.size)
    # Training-loop
    for i in range(args.max_iter):
        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_train["image"].d = image
        input_train["label"].d = label
        mdl.set_mix_ratio()
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Model update by solver
        if args.solver == "Momentum":
            if i == args.max_iter / 2:
                solver.set_learning_rate(args.learning_rate / 10.0)
            if i == args.max_iter / 4 * 3:
                solver.set_learning_rate(args.learning_rate / 10.0**2)
        solver.update()

        # Validation
        if (i + 1) % args.val_interval == 0 or i == 0:
            ve = 0.
            vdata._reset()
            vdata_pred = np.zeros((n_valid_samples, n_class))
            vdata_label = np.zeros((n_valid_samples, 1), dtype=np.int32)
            for j in range(0, n_valid_samples, args.batch_size):
                image, label = vdata.next()
                input_valid["image"].d = image
                pred_valid.forward()
                # number of valid samples covered by this minibatch
                n = min(args.batch_size, n_valid_samples - j)
                vdata_pred[j:j + n] = pred_valid.d[:n]
                vdata_label[j:j + n] = label[:n]
            ve = categorical_error(vdata_pred, vdata_label)
            monitor_verr.add(i + 1, ve)

        if int((i + 1) % args.model_save_interval) == 0:
            nn.save_parameters(
                os.path.join(args.save_path, 'params_%06d.h5' % (i + 1)))

        # Monitoring
        monitor_loss.add(i + 1, loss_train.d.copy())
        monitor_time.add(i + 1)

    nn.save_parameters(
        os.path.join(args.save_path, 'params_%06d.h5' % (args.max_iter)))
Example No. 25
def train():
    rng = np.random.RandomState(803)

    conf = get_config()

    comm = init_nnabla(conf)

    # create data iterator
    if conf.dataset == "cityscapes":
        data_list = get_cityscape_datalist(conf.cityscapes,
                                           save_file=comm.rank == 0)
        n_class = conf.cityscapes.n_label_ids
        use_inst = True

        data_iter = create_cityscapes_iterator(conf.batch_size,
                                               data_list,
                                               comm=comm,
                                               image_shape=conf.image_shape,
                                               rng=rng,
                                               flip=conf.use_flip)

    elif conf.dataset == "ade20k":
        data_list = get_ade20k_datalist(conf.ade20k, save_file=comm.rank == 0)
        n_class = conf.ade20k.n_label_ids + 1  # class id + unknown
        use_inst = False

        load_shape = tuple(
            x + 30
            for x in conf.image_shape) if conf.use_crop else conf.image_shape
        data_iter = create_ade20k_iterator(conf.batch_size,
                                           data_list,
                                           comm=comm,
                                           load_shape=load_shape,
                                           crop_shape=conf.image_shape,
                                           rng=rng,
                                           flip=conf.use_flip)

    else:
        raise NotImplementedError(
            "Currently dataset {} is not supported.".format(conf.dataset))

    real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape)
    obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)

    if use_inst:
        ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape)
        obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    # generator
    generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape)
    z = F.randn(shape=(conf.batch_size, conf.z_dim))
    fake = generator(z, mask)

    # unlinking
    ul_mask, ul_fake = get_unlinked_all(mask, fake)

    # discriminator
    discriminator = PatchGAN(n_scales=conf.d_n_scales)
    d_input_real = F.concatenate(real, ul_mask, axis=1)
    d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1)
    d_real_out, d_real_feats = discriminator(d_input_real)
    d_fake_out, d_fake_feats = discriminator(d_input_fake)

    g_gan, g_feat, d_real, d_fake = discriminator.get_loss(
        d_real_out,
        d_real_feats,
        d_fake_out,
        d_fake_feats,
        use_fm=conf.use_fm,
        fm_lambda=conf.lambda_fm,
        gan_loss_type=conf.gan_loss_type)

    def _rescale(x):
        return rescale_values(x,
                              input_min=-1,
                              input_max=1,
                              output_min=0,
                              output_max=255)

    g_vgg = vgg16_perceptual_loss(_rescale(ul_fake),
                                  _rescale(real)) * conf.lambda_vgg

    set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg)

    # loss
    g_loss = g_gan + g_feat + g_vgg
    d_loss = (d_real + d_fake) / 2

    # load params
    if conf.load_params is not None:
        print("load parameters from {}".format(conf.load_params))
        nn.load_parameters(conf.load_params)

    # Setup Solvers
    g_solver = S.Adam(beta1=0.)
    g_solver.set_parameters(get_params_startswith("spade_generator"))

    d_solver = S.Adam(beta1=0.)
    d_solver.set_parameters(get_params_startswith("discriminator"))

    # lr scheduler
    g_lrs = LinearDecayScheduler(start_lr=conf.g_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)
    d_lrs = LinearDecayScheduler(start_lr=conf.d_lr,
                                 end_lr=0.,
                                 start_iter=100,
                                 end_iter=200)

    ipe = get_iteration_per_epoch(data_iter._size,
                                  conf.batch_size,
                                  round="ceil")

    if not conf.show_interval:
        conf.show_interval = ipe
    if not conf.save_interval:
        conf.save_interval = ipe
    if not conf.niter:
        conf.niter = 200 * ipe

    # Setup Reporter
    losses = {
        "g_gan": g_gan,
        "g_feat": g_feat,
        "g_vgg": g_vgg,
        "d_real": d_real,
        "d_fake": d_fake
    }
    reporter = Reporter(comm,
                        losses,
                        conf.save_path,
                        nimage_per_epoch=min(conf.batch_size, 5),
                        show_interval=10)
    progress_iterator = trange(conf.niter, disable=comm.rank > 0)
    reporter.start(progress_iterator)

    colorizer = Colorize(n_class)

    # output all config and dump to file
    if comm.rank == 0:
        conf.dump_to_stdout()
        write_yaml(os.path.join(conf.save_path, "config.yaml"), conf)

    epoch = 0
    for itr in progress_iterator:
        if itr % ipe == 0:
            g_lr = g_lrs(epoch)
            d_lr = d_lrs(epoch)
            g_solver.set_learning_rate(g_lr)
            d_solver.set_learning_rate(d_lr)
            if comm.rank == 0:
                print(
                    "\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format(
                        epoch, g_lr, d_lr))

            epoch += 1

        if conf.dataset == "cityscapes":
            im, ist, obj = data_iter.next()
            ist_mask.d = ist
        elif conf.dataset == "ade20k":
            im, obj = data_iter.next()
        else:
            raise NotImplementedError()

        real.d = im
        obj_mask.d = obj

        # create the fake image
        fake.forward()

        # update discriminator
        d_solver.zero_grad()
        d_loss.forward()
        d_loss.backward(clear_buffer=True)
        comm.all_reduced_solver_update(d_solver)

        # update generator
        ul_fake.grad.zero()
        g_solver.zero_grad()
        g_loss.forward()
        g_loss.backward(clear_buffer=True)

        # backward generator
        fake.backward(grad=None, clear_buffer=True)
        comm.all_reduced_solver_update(g_solver)

        # report iteration progress
        reporter()

        # report epoch progress
        show_epoch = itr // conf.show_interval
        if (itr % conf.show_interval) == 0:
            show_images = {
                "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)),
                "ObjectMask": colorizer(obj).astype(np.uint8),
                "GeneratedImage": fake.data.get_data("r").transpose(
                    (0, 2, 3, 1))
            }

            reporter.step(show_epoch, show_images)

        if (itr % conf.save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(conf.save_path,
                             'param_{:03d}.h5'.format(show_epoch)))

    if comm.rank == 0:
        nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
Example No. 26
def train():
    '''
    Main script.
    '''
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # TRAIN
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    x = image / 255.0
    t_onehot = F.one_hot(label, (10, ))
    with nn.parameter_scope("capsnet"):
        c1, pcaps, u_hat, caps, pred = model.capsule_net(
            x,
            test=False,
            aug=True,
            grad_dynamic_routing=args.grad_dynamic_routing)
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot)
    loss_margin, loss_reconst, loss = model.capsule_loss(
        pred, t_onehot, recon, x)
    pred.persistent = True

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vx = vimage / 255.0
    with nn.parameter_scope("capsnet"):
        _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    train_iter = int(60000 / args.batch_size)
    val_iter = int(10000 / args.batch_size)
    logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter))
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1)
    monitor_rloss = MonitorSeries("Training reconstruction loss",
                                  monitor,
                                  interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_lr = MonitorSeries("Learning rate", monitor, interval=1)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)

    # Training loop.
    for e in range(args.max_epochs):

        # Learning rate decay
        learning_rate = solver.learning_rate()
        if e != 0:
            learning_rate *= 0.9
        solver.set_learning_rate(learning_rate)
        monitor_lr.add(e, learning_rate)

        # Training
        train_error = 0.0
        train_loss = 0.0
        train_mloss = 0.0
        train_rloss = 0.0
        for i in range(train_iter):
            image.d, label.d = data.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            train_error += categorical_error(pred.d, label.d)
            train_loss += loss.d
            train_mloss += loss_margin.d
            train_rloss += loss_reconst.d
        train_error /= train_iter
        train_loss /= train_iter
        train_mloss /= train_iter
        train_rloss /= train_iter

        # Validation
        val_error = 0.0
        for j in range(val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            val_error += categorical_error(vpred.d, vlabel.d)
        val_error /= val_iter

        # Monitor
        monitor_time.add(e)
        monitor_loss.add(e, train_loss)
        monitor_mloss.add(e, train_mloss)
        monitor_rloss.add(e, train_rloss)
        monitor_err.add(e, train_error)
        monitor_verr.add(e, val_error)
        nn.save_parameters(
            os.path.join(args.monitor_path, 'params_%06d.h5' % e))
Ejemplo n.º 27
0
def main(args):
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000
    n_train_data = 50000
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    iter_epoch = int(n_train_data / batch_size)
    n_iter = n_epoch * iter_epoch
    extension_module = args.context
    alpha = args.alpha

    # Supervised Model 
    ## ERM
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l_0 = nn.Variable((batch_size, m, h, w))
    y_l_0 = nn.Variable((batch_size, 1))
    pred = cnn_model_003(ctx, x_l_0)
    loss_ce = ce_loss(ctx, pred, y_l_0)
    loss_er = er_loss(ctx, pred)
    loss_supervised = loss_ce + loss_er
    ## VRM (mixup)
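    # mixup (vicinal risk minimization): with a single coefficient
    # coef ~ Beta(alpha, alpha) drawn in the training loop and broadcast to the
    # input/prediction shapes,
    #   x_l_m = coef * x_l_0 + (1 - coef) * x_l_1
    #   y_l_m = coef * onehot(y_l_0) + (1 - coef) * onehot(y_l_1)
    # and the soft cross entropy is taken between y_l_m and the prediction on
    # the mixed input x_l_m.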
    x_l_1 = nn.Variable((batch_size, m, h, w))
    y_l_1 = nn.Variable((batch_size, 1))
    coef = nn.Variable()
    coef_b = F.broadcast(coef.reshape([1]*x_l_0.ndim, unlink=True), x_l_0.shape)
    x_l_m = coef_b * x_l_0 + (1 - coef_b) * x_l_1
    coef_b = F.broadcast(coef.reshape([1]*pred.ndim, unlink=True), pred.shape)
    y_l_m = coef_b * F.one_hot(y_l_0, (n_cls, )) \
            + (1-coef_b) * F.one_hot(y_l_1, (n_cls, ))
    x_l_m.need_grad, y_l_m.need_grad = False, False
    pred_m = cnn_model_003(ctx, x_l_m)
    loss_er_m = er_loss(ctx, pred_m)  #todo: need?
    loss_ce_m = ce_loss_soft(ctx, pred_m, y_l_m)  # soft CE uses the prediction on the mixed input
    loss_supervised_m = loss_ce_m #+ loss_er_m
    
    # Semi-Supervised Model
    ## ERM
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0 = cnn_model_003(ctx, x_u0)
    pred_x_u1 = cnn_model_003(ctx, x_u1)
    pred_x_u0.persistent, pred_x_u1.persistent = True, True
    loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1)
    loss_er0 = er_loss(ctx, pred_x_u0)
    loss_er1 = er_loss(ctx, pred_x_u1)
    loss_unsupervised = loss_sr + loss_er0 + loss_er1
    ## VRM (mixup)
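    # unlabeled mixup: the soft targets are the network's own predictions on
    # the two unlabeled inputs. pred_x_u0_/pred_x_u1_ below are fresh Variables
    # that only share the underlying data arrays, so the mixed pseudo-label
    #   y_u_m = coef_u * pred_x_u0_ + (1 - coef_u) * pred_x_u1_
    # does not backpropagate into the branches that produced those predictions.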
    x_u2 = nn.Variable((batch_size, m, h, w))  # not to overwrite x_u1.d
    coef_u = nn.Variable()
    coef_u_b = F.broadcast(coef_u.reshape([1]*x_u0.ndim, unlink=True), x_u0.shape)
    x_u_m = coef_u_b * x_u0 + (1-coef_u_b) * x_u2
    pred_x_u0_ = nn.Variable(pred_x_u0.shape)  # unlink forward pass but reuse result
    pred_x_u1_ = nn.Variable(pred_x_u1.shape)
    pred_x_u0_.data = pred_x_u0.data
    pred_x_u1_.data = pred_x_u1.data
    coef_u_b = F.broadcast(coef_u.reshape([1]*pred_x_u0.ndim, unlink=True), pred_x_u0.shape)
    y_u_m = coef_u_b * pred_x_u0_ + (1-coef_u_b) * pred_x_u1_
    x_u_m.need_grad, y_u_m.need_grad = False, False
    pred_x_u_m = cnn_model_003(ctx, x_u_m)
    loss_er_u_m = er_loss(ctx, pred_x_u_m)  #todo: need?
    loss_ce_u_m = ce_loss_soft(ctx, pred_x_u_m, y_u_m)
    loss_unsupervised_m = loss_ce_u_m #+ loss_er_u_m
    
    # Evaluation Model
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval = cnn_model_003(ctx, x_eval, test=True)
    
    # Solver
    with nn.context_scope(ctx):
        solver = S.Adam(alpha=learning_rate)
        solver.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.
    ve_best = 1.
    save_path_prev = ""
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()
        
        x_l_0.d, _, y_l_0.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train
        ## forward (supervised and its mixup)
        loss_supervised.forward(clear_no_need_grad=True)
        coef_data = np.random.beta(alpha, alpha)
        coef.d = coef_data
        x_l_1.d = np.random.permutation(x_l0_data)
        y_l_1.d = np.random.permutation(y_l_data)
        loss_supervised_m.forward(clear_no_need_grad=True)
        ## forward (unsupervised and its mixup)
        loss_unsupervised.forward(clear_no_need_grad=True)
        coef_data = np.random.beta(alpha, alpha)
        coef_u.d = coef_data
        x_u2.d = np.random.permutation(x_u1_data)
        loss_unsupervised_m.forward(clear_no_need_grad=True)
        
        ## backward
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=False)
        loss_supervised_m.backward(clear_buffer=False)
        loss_unsupervised.backward(clear_buffer=False)
        loss_unsupervised_m.backward(clear_buffer=True)
        solver.update()
        
        # Evaluate
        if int((i+1) % iter_epoch) == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            ve /= iter_val
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, 
                (1. - ve) * 100)
            print(msg)
            if ve < ve_best:
                if not os.path.exists(args.model_save_path):
                    os.makedirs(args.model_save_path)
                if save_path_prev != "":
                    os.remove(save_path_prev)
                save_path = os.path.join(
                    args.model_save_path, 'params_%06d.h5' % epoch)
                nn.save_parameters(save_path)
                save_path_prev = save_path
                ve_best = ve
            st = time.time()
            epoch += 1
Ejemplo n.º 28
0
def train():
    if Config.USE_NW:
        env = Environment('Pong-v0')
    else:
        env = gym.make('Pong-v0')

    extension_module = Config.CONTEXT
    logger.info("Running in {}".format(extension_module))
    ctx = extension_context(extension_module, device_id=Config.DEVICE_ID)
    nn.set_default_context(ctx)

    monitor = Monitor(Config.MONITOR_PATH)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_reward = MonitorSeries("Training reward", monitor, interval=1)
    monitor_q = MonitorSeries("Training q", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)

    # placeholder
    image = nn.Variable([
        Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH,
        Config.FRAME_HEIGHT
    ])
    image_target = nn.Variable([
        Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH,
        Config.FRAME_HEIGHT
    ])

    nn.clear_parameters()

    # create network
    with nn.parameter_scope("dqn"):
        q = dqn(image, test=False)
        q.persistent = True  # keep the output buffer from being cleared at backward
    with nn.parameter_scope("target"):
        target_q = dqn(image_target, test=False)
        target_q.persistent = True  # keep the output buffer from being cleared at backward

    # loss definition
    a = nn.Variable([Config.BATCH_SIZE, 1])
    q_val = F.sum(F.one_hot(a, (6, )) * q, axis=1, keepdims=True)
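    # F.one_hot(a, (6,)) encodes the stored action over the 6 discrete actions
    # this script assumes for Pong, so q_val keeps only Q(s, a) for that action
    # (shape (BATCH_SIZE, 1)).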
    t = nn.Variable([Config.BATCH_SIZE, 1])
    loss = F.mean(F.squared_error(t, q_val))

    if Config.RESUME:
        logger.info('load model: {}'.format(Config.RESUME))
        nn.load_parameters(Config.RESUME)

    # setup solver
    # update dqn parameter only
    solver = S.RMSprop(lr=Config.LEARNING_RATE,
                       decay=Config.DECAY,
                       eps=Config.EPSILON)
    with nn.parameter_scope("dqn"):
        solver.set_parameters(nn.get_parameters())

    # training
    epsilon = Config.INIT_EPSILON
    experiences = []
    step = 0
    for i in range(Config.EPISODE_LENGTH):
        logger.info("EPISODE {}".format(i))
        done = False
        observation = env.reset()
        for _ in range(30):  # advance 30 warm-up frames with action 0
            observation_next, reward, done, info = env.step(0)
            observation_next = preprocess_frame(observation_next)
        # join 4 frame
        state = [observation_next for _ in range(Config.STATE_LENGTH)]
        state = np.stack(state, axis=0)
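        # the DQN input is a stack of STATE_LENGTH consecutive preprocessed
        # frames; the stack is bootstrapped here by repeating the last warm-up frame.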
        total_reward = 0
        while not done:
            # select action
            if step % Config.ACTION_INTERVAL == 0:
                if random.random() > epsilon or len(
                        experiences) >= Config.REPLAY_MEMORY_SIZE:
                    # inference
                    image.d = state
                    q.forward()
                    action = np.argmax(q.d)
                else:
                    # random action
                    if Config.USE_NW:
                        action = env.sample()
                    else:
                        action = env.action_space.sample()  # TODO refactor
                if epsilon > Config.MIN_EPSILON:
                    epsilon -= Config.EPSILON_REDUCTION_PER_STEP

            # get next environment
            observation_next, reward, done, info = env.step(action)
            observation_next = preprocess_frame(observation_next)
            total_reward += reward
            # TODO clip reward

            # update replay memory (FIFO)
            state_next = np.append(state[1:, :, :],
                                   observation_next[np.newaxis, :, :],
                                   axis=0)
            experiences.append((state_next, reward, action, state, done))
            if len(experiences) > Config.REPLAY_MEMORY_SIZE:
                experiences.pop(0)

            # update network
            if step % Config.NET_UPDATE_INTERVAL == 0 and len(
                    experiences) > Config.INIT_REPLAY_SIZE:
                logger.info("update {}".format(step))
                batch = random.sample(experiences, Config.BATCH_SIZE)
                batch_observation_next = np.array([b[0] for b in batch])
                batch_reward = np.array([b[1] for b in batch])
                batch_action = np.array([b[2] for b in batch])
                batch_observation = np.array([b[3] for b in batch])
                batch_done = np.array([b[4] for b in batch], dtype=np.float32)

                batch_reward = batch_reward[:, np.newaxis]
                batch_action = batch_action[:, np.newaxis]
                batch_done = batch_done[:, np.newaxis]

                image.d = batch_observation.astype(np.float32)
                image_target.d = batch_observation_next.astype(np.float32)
                a.d = batch_action
                q_val.forward()  # XXX
                target_q.forward()
                t.d = batch_reward + (1 - batch_done) * Config.GAMMA * np.max(
                    target_q.d, axis=1, keepdims=True)
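                # standard DQN target:
                #   t = r + (1 - done) * gamma * max_a' Q_target(s', a'),
                # with the max taken over the target network's outputs.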
                solver.zero_grad()
                loss.forward()
                loss.backward()

                monitor_loss.add(step, loss.d.copy())
                monitor_reward.add(step, total_reward)
                monitor_q.add(step, np.mean(q.d.copy()))
                monitor_time.add(step)
                # TODO weight clip
                solver.update()
                logger.info("update done {}".format(step))

            # update target network
            if step % Config.TARGET_NET_UPDATE_INTERVAL == 0:
                # copy parameter from dqn to target
                with nn.parameter_scope("dqn"):
                    src = nn.get_parameters()
                with nn.parameter_scope("target"):
                    dst = nn.get_parameters()
                for (s_key, s_val), (d_key,
                                     d_val) in zip(src.items(), dst.items()):
                    # Variable#d method is reference
                    d_val.d = s_val.d.copy()

            if step % Config.MODEL_SAVE_INTERVAL == 0:
                logger.info("save model")
                nn.save_parameters("model_{}.h5".format(step))

            step += 1
            observation = observation_next
            state = state_next
Ejemplo n.º 29
0
    def _build(self):
        # infer variable
        self.infer_obs_t = infer_obs_t = nn.Variable((1, 4, 84, 84))
        # inference output
        self.infer_q_t,\
        self.infer_probs_t, _ = self.q_function(infer_obs_t, self.num_actions,
                                                self.min_v, self.max_v,
                                                self.num_bins, 'q_func')
        self.infer_t = F.sink(self.infer_q_t, self.infer_probs_t)

        # train variables
        self.obss_t = nn.Variable((self.batch_size, 4, 84, 84))
        self.acts_t = nn.Variable((self.batch_size, 1))
        self.rews_tp1 = nn.Variable((self.batch_size, 1))
        self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84))
        self.ters_tp1 = nn.Variable((self.batch_size, 1))

        # training output
        q_t, probs_t, dists = self.q_function(self.obss_t, self.num_actions,
                                              self.min_v, self.max_v,
                                              self.num_bins, 'q_func')
        q_tp1, probs_tp1, _ = self.q_function(self.obss_tp1, self.num_actions,
                                              self.min_v, self.max_v,
                                              self.num_bins, 'target_q_func')

        expand_last = lambda x: F.reshape(x, x.shape + (1, ))
        flat = lambda x: F.reshape(x, (-1, 1))

        # extract selected dimension
        a_t_one_hot = expand_last(F.one_hot(self.acts_t, (self.num_actions, )))
        probs_t_selected = F.max(probs_t * a_t_one_hot, axis=1)
        # extract max dimension
        _, indices = F.max(q_tp1, axis=1, keepdims=True, with_index=True)
        a_tp1_one_hot = expand_last(F.one_hot(indices, (self.num_actions, )))
        probs_tp1_best = F.max(probs_tp1 * a_tp1_one_hot, axis=1)

        # clipping reward
        clipped_rews_tp1 = clip_by_value(self.rews_tp1, -1.0, 1.0)

        disc_q_tp1 = F.reshape(dists, (1, -1)) * (1.0 - self.ters_tp1)
        t_z = clip_by_value(clipped_rews_tp1 + self.gamma * disc_q_tp1,
                            self.min_v, self.max_v)

        # update indices
        b = (t_z - self.min_v) / ((self.max_v - self.min_v) /
                                  (self.num_bins - 1))
        l = F.floor(b)
        l_mask = F.reshape(F.one_hot(flat(l), (self.num_bins, )),
                           (-1, self.num_bins, self.num_bins))
        u = F.ceil(b)
        u_mask = F.reshape(F.one_hot(flat(u), (self.num_bins, )),
                           (-1, self.num_bins, self.num_bins))

        m_l = expand_last(probs_tp1_best * (1 - (b - l)))
        m_u = expand_last(probs_tp1_best * (b - l))
        m = F.sum(m_l * l_mask + m_u * u_mask, axis=1)
        m.need_grad = False
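        # categorical-DQN style projection: each target atom t_z falls at a
        # fractional bin index b between l = floor(b) and u = ceil(b); the mass
        # from probs_tp1_best is split as (1 - (b - l)) onto l and (b - l) onto
        # u, and the loss below is the cross entropy between the projected
        # distribution m and the predicted distribution of the stored action.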

        self.loss = -F.mean(F.sum(m * F.log(probs_t_selected + 1e-10), axis=1))

        # optimizer
        self.solver = S.RMSprop(self.lr, 0.95, 1e-2)

        # weights and biases
        with nn.parameter_scope('q_func'):
            self.params = nn.get_parameters()
        with nn.parameter_scope('target_q_func'):
            self.target_params = nn.get_parameters()

        # set q function parameters to solver
        self.solver.set_parameters(self.params)
Ejemplo n.º 30
0
def main(args):
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000
    n_train_data = 50000
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    iter_epoch = int(n_train_data / batch_size)
    n_iter = n_epoch * iter_epoch
    extension_module = args.context
    alpha = args.alpha

    # Supervised Model
    ## ERM
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l_0 = nn.Variable((batch_size, m, h, w))
    y_l_0 = nn.Variable((batch_size, 1))
    pred = cnn_model_003(ctx, x_l_0)
    loss_ce = ce_loss(ctx, pred, y_l_0)
    loss_er = er_loss(ctx, pred)
    loss_supervised = loss_ce + loss_er
    ## VRM (mixup)
    x_l_1 = nn.Variable((batch_size, m, h, w))
    y_l_1 = nn.Variable((batch_size, 1))
    coef = nn.Variable()
    coef_b = F.broadcast(coef.reshape([1] * x_l_0.ndim, unlink=True),
                         x_l_0.shape)
    x_l_m = coef_b * x_l_0 + (1 - coef_b) * x_l_1
    coef_b = F.broadcast(coef.reshape([1] * pred.ndim, unlink=True),
                         pred.shape)
    y_l_m = coef_b * F.one_hot(y_l_0, (n_cls, )) \
            + (1-coef_b) * F.one_hot(y_l_1, (n_cls, ))
    x_l_m.need_grad, y_l_m.need_grad = False, False
    pred_m = cnn_model_003(ctx, x_l_m)
    loss_er_m = er_loss(ctx, pred_m)  #todo: need?
    loss_ce_m = ce_loss_soft(ctx, pred_m, y_l_m)  # soft CE uses the prediction on the mixed input
    loss_supervised_m = loss_ce_m  #+ loss_er_m

    # Semi-Supervised Model
    ## ERM
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0 = cnn_model_003(ctx, x_u0)
    pred_x_u1 = cnn_model_003(ctx, x_u1)
    pred_x_u0.persistent, pred_x_u1.persistent = True, True
    loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1)
    loss_er0 = er_loss(ctx, pred_x_u0)
    loss_er1 = er_loss(ctx, pred_x_u1)
    loss_unsupervised = loss_sr + loss_er0 + loss_er1
    ## VRM (mixup)
    x_u2 = nn.Variable((batch_size, m, h, w))  # not to overwrite x_u1.d
    coef_u = nn.Variable()
    coef_u_b = F.broadcast(coef_u.reshape([1] * x_u0.ndim, unlink=True),
                           x_u0.shape)
    x_u_m = coef_u_b * x_u0 + (1 - coef_u_b) * x_u2
    pred_x_u0_ = nn.Variable(
        pred_x_u0.shape)  # unlink forward pass but reuse result
    pred_x_u1_ = nn.Variable(pred_x_u1.shape)
    pred_x_u0_.data = pred_x_u0.data
    pred_x_u1_.data = pred_x_u1.data
    coef_u_b = F.broadcast(coef_u.reshape([1] * pred_x_u0.ndim, unlink=True),
                           pred_x_u0.shape)
    y_u_m = coef_u_b * pred_x_u0_ + (1 - coef_u_b) * pred_x_u1_
    x_u_m.need_grad, y_u_m.need_grad = False, False
    pred_x_u_m = cnn_model_003(ctx, x_u_m)
    loss_er_u_m = er_loss(ctx, pred_x_u_m)  #todo: need?
    loss_ce_u_m = ce_loss_soft(ctx, pred_x_u_m, y_u_m)
    loss_unsupervised_m = loss_ce_u_m  #+ loss_er_u_m

    # Evaluation Model
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval = cnn_model_003(ctx, x_eval, test=True)

    # Solver
    with nn.context_scope(ctx):
        solver = S.Adam(alpha=learning_rate)
        solver.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path,
                                    u_train_path,
                                    test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.
    ve_best = 1.
    save_path_prev = ""
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()

        x_l_0.d, _, y_l_0.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train
        ## forward (supervised and its mixup)
        loss_supervised.forward(clear_no_need_grad=True)
        coef_data = np.random.beta(alpha, alpha)
        coef.d = coef_data
        x_l_1.d = np.random.permutation(x_l0_data)
        y_l_1.d = np.random.permutation(y_l_data)
        loss_supervised_m.forward(clear_no_need_grad=True)
        ## forward (unsupervised and its mixup)
        loss_unsupervised.forward(clear_no_need_grad=True)
        coef_data = np.random.beta(alpha, alpha)
        coef_u.d = coef_data
        x_u2.d = np.random.permutation(x_u1_data)
        loss_unsupervised_m.forward(clear_no_need_grad=True)

        ## backward
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=False)
        loss_supervised_m.backward(clear_buffer=False)
        loss_unsupervised.backward(clear_buffer=False)
        loss_unsupervised_m.backward(clear_buffer=True)
        solver.update()

        # Evaluate
        if int((i + 1) % iter_epoch) == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            ve /= iter_val
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st, (1. - ve) * 100)
            print(msg)
            if ve < ve_best:
                if not os.path.exists(args.model_save_path):
                    os.makedirs(args.model_save_path)
                if save_path_prev != "":
                    os.remove(save_path_prev)
                save_path = os.path.join(args.model_save_path,
                                         'params_%06d.h5' % epoch)
                nn.save_parameters(save_path)
                save_path_prev = save_path
                ve_best = ve
            st = time.time()
            epoch += 1