def loss(target_action, target_action_type, target_action_mask, rule_prob, terminal_gen_action_prob, token_prob, copy_prob):
    batch_size, max_action_length, _ = target_action.shape
    _, _, rule_num = rule_prob.shape
    _, _, token_num = token_prob.shape
    _, _, max_query_length = copy_prob.shape

    # (batch_size, max_action_length)
    target_rule, target_token, target_copy = F.split(target_action, axis=2)

    target_rule = F.reshape(target_rule, (batch_size, max_action_length, 1))
    target_rule = F.one_hot(
        target_rule, (rule_num, ))  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = rule_prob * target_rule  # (batch_size, max_action_length, rule_num)
    rule_tgt_prob = F.sum(rule_tgt_prob, axis=2)  # (batch_size, max_action_length)

    target_token = F.reshape(target_token, (batch_size, max_action_length, 1))
    target_token = F.one_hot(
        target_token, (token_num, ))  # (batch_size, max_action_length, token_num)
    token_tgt_prob = token_prob * target_token  # (batch_size, max_action_length, token_num)
    token_tgt_prob = F.sum(token_tgt_prob, axis=2)  # (batch_size, max_action_length)

    target_copy = F.reshape(target_copy, (batch_size, max_action_length, 1))
    target_copy = F.one_hot(
        target_copy, (max_query_length, ))  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = copy_prob * target_copy  # (batch_size, max_action_length, max_query_length)
    copy_tgt_prob = F.sum(copy_tgt_prob, axis=2)  # (batch_size, max_action_length)

    # (batch_size, max_action_length)
    gen_token_prob, copy_token_prob = F.split(terminal_gen_action_prob, axis=2)
    # (batch_size, max_action_length)
    rule_mask, token_mask, copy_mask = F.split(target_action_type, axis=2)

    # (batch_size, max_action_length)
    target_prob = rule_mask * rule_tgt_prob + \
        token_mask * gen_token_prob * token_tgt_prob + \
        copy_mask * copy_token_prob * copy_tgt_prob
    # (batch_size, max_action_length)
    likelihood = F.log(target_prob + 1e-7)
    loss = -likelihood * target_action_mask
    # (batch_size)
    loss = F.sum(loss, axis=1)
    return F.mean(loss)
def softmax_cross_entropy_backward(inputs, axis=None):
    """
    Args:
      inputs (list of nn.Variable): Incoming grads/inputs to/of the forward function.
      kwargs (dict of arguments): Dictionary of the corresponding function arguments.

    Return:
      list of Variable: Return the gradients wrt inputs of the corresponding function.
    """
    dy = inputs[0]
    x0 = inputs[1]
    t0 = inputs[2]
    D = len(x0.shape)
    axis = positive_axis(axis, D)
    c0 = x0.shape[axis]
    t0_shape = [s for s in t0.shape if s != 1]
    u0 = F.reshape(t0, (-1, 1), inplace=False)
    u1 = F.one_hot(u0, (c0, ))
    to = F.reshape(u1, t0_shape + [c0, ])
    to = no_grad(to)  # detach the one-hot target from the graph
    if axis != len(to.shape) - 1:
        oaxes = [i for i in range(len(t0_shape))]
        taxes = oaxes[:axis] + [to.ndim - 1] + oaxes[axis:]
        to = F.transpose(to, taxes)
    dx0 = dy * (F.softmax(x0, axis=axis) - to)
    return dx0, None
def mix_data(self, image, label):
    '''
    Define mixed data Variables.

    Args:
        image(Variable): (B, C, H, W) or (B, H, W, C)
        label(Variable): (B, 1) of integers in [0, num_classes)

    Returns:
        image(Variable): mixed data
        label(Variable): mixed label with (B, num_classes)
    '''
    if image.shape[0] % 2 != 0:
        raise ValueError(
            'Please use an even batch size with this implementation of mixup regularization. Given {}.'
            .format(image.shape[0]))
    image2 = image[::-1]
    label = F.one_hot(label, (self.num_classes, ))
    label2 = label[::-1]
    self.lam = nn.Variable((image.shape[0], 1, 1, 1))
    if get_nnabla_version_integer() < 10700:
        raise ValueError(
            'This does not work with nnabla versions older than 1.7.0 due to [a bug](https://github.com/sony/nnabla/pull/608). Please update nnabla.')
    llam = F.reshape(self.lam, (-1, 1))
    self.reset_mixup_ratio()  # Call it to be safe.
    mimage = self.lam * image + (1 - self.lam) * image2
    mlabel = llam * label + (1 - llam) * label2
    return mimage, mlabel
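# Added sketch (not part of the original example): a minimal, self-contained illustration of
# the mixing math used in mix_data above. It one-hots the labels, pairs each sample with the
# reversed batch, and blends with lam drawn from Beta(alpha, alpha). The sizes and the
# names batch_size/num_classes/alpha are illustrative assumptions.
import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size, num_classes, alpha = 8, 10, 0.2
image = nn.Variable((batch_size, 3, 32, 32))
label = nn.Variable((batch_size, 1))
lam = nn.Variable((batch_size, 1, 1, 1))       # per-sample mixing ratio

onehot = F.one_hot(label, (num_classes, ))     # (B, num_classes)
image2 = image[::-1]                           # partner samples: the reversed batch
onehot2 = onehot[::-1]
llam = F.reshape(lam, (-1, 1))
mimage = lam * image + (1 - lam) * image2
mlabel = llam * onehot + (1 - llam) * onehot2

image.d = np.random.rand(*image.shape)
label.d = np.random.randint(0, num_classes, size=label.shape)
lam.d = np.random.beta(alpha, alpha, size=lam.shape)
F.sink(mimage, mlabel).forward()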
def random_generate(self, num_images, path): # Generate from the uniform prior of the base model indices = F.randint(low=0, high=self.num_embedding, shape=[num_images] + self.latent_shape) indices = F.reshape(indices, (-1, ), inplace=True) quantized = F.embed(indices, self.base_model.vq.embedding_weight) quantized = F.transpose( quantized.reshape([num_images] + self.latent_shape + [quantized.shape[-1]]), (0, 3, 1, 2)) img_gen_uniform_prior = self.base_model(quantized, quantized_as_input=True, test=True) # Generate images using pixelcnn prior indices = nn.Variable.from_numpy_array( np.zeros(shape=[num_images] + self.latent_shape)) labels = F.randint(low=0, high=self.num_classes, shape=(num_images, 1)) labels = F.one_hot(labels, shape=(self.num_classes, )) # Sample from pixelcnn - pixel by pixel import torch # Numpy behavior is different and not giving correct output for i in range(self.latent_shape[0]): for j in range(self.latent_shape[1]): quantized = F.embed(indices.reshape((-1, )), self.base_model.vq.embedding_weight) quantized = F.transpose( quantized.reshape([num_images] + self.latent_shape + [quantized.shape[-1]]), (0, 3, 1, 2)) indices_sample = self.prior(quantized, labels) indices_prob = F.reshape(indices_sample, indices.shape + (indices_sample.shape[-1], ), inplace=True)[:, i, j] indices_prob = F.softmax(indices_prob) indices_prob_tensor = torch.from_numpy(indices_prob.d) sample = indices_prob_tensor.multinomial(1).squeeze().numpy() indices[:, i, j] = sample print(indices.d) quantized = F.embed(indices.reshape((-1, )), self.base_model.vq.embedding_weight) quantized = F.transpose( quantized.reshape([num_images] + self.latent_shape + [quantized.shape[-1]]), (0, 3, 1, 2)) img_gen_pixelcnn_prior = self.base_model(quantized, quantized_as_input=True, test=True) self.save_image(img_gen_uniform_prior, os.path.join(path, 'generate_uniform.png')) self.save_image(img_gen_pixelcnn_prior, os.path.join(path, 'generate_pixelcnn.png')) print('Random labels generated for pixelcnn prior:', list(F.max(labels, axis=1, only_index=True).d))
def encode_inputs(inst_label, id_label, n_ids, use_encoder=False, channel_last=False):
    """
    :param inst_label: (N, H, W) or (N, H, W, 1)
    :param id_label: (N, H, W) or (N, H, W, 1)
    :param n_ids: number of label ids (depth of the one-hot encoding)
    :param use_encoder: boolean
    :param channel_last: boolean. If False, outputs are transposed to channel-first.
    :return:
    """
    # id (index) -> onehot
    _check_intput(id_label)
    if len(id_label.shape) == 3:
        id_label = id_label.reshape(id_label.shape + (1, ))
    id_onehot = F.one_hot(id_label, shape=(n_ids, ))

    # inst -> boundary map
    _check_intput(inst_label)
    bm = inst_to_boundary(inst_label)
    if len(bm.shape) == 3:
        bm = bm.reshape(bm.shape + (1, ))

    if use_encoder:
        # todo: implement encoder network
        pass

    if channel_last:
        return id_onehot, bm

    return F.transpose(id_onehot, (0, 3, 1, 2)), F.transpose(bm, (0, 3, 1, 2))
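# Added sketch (assumption-labeled, not part of the original example): the id -> one-hot step
# above in isolation. An integer class-id map (N, H, W) gets a trailing singleton axis, is
# one-hot encoded over n_ids classes, and is transposed to channel-first for the generator.
import numpy as np
import nnabla as nn
import nnabla.functions as F

n_ids = 4
id_label = nn.Variable((2, 8, 8))                    # (N, H, W) integer class ids
id_label.d = np.random.randint(0, n_ids, size=id_label.shape)
x = F.reshape(id_label, id_label.shape + (1, ))      # (N, H, W, 1)
onehot = F.one_hot(x, shape=(n_ids, ))               # (N, H, W, n_ids)
mask = F.transpose(onehot, (0, 3, 1, 2))             # (N, n_ids, H, W)
mask.forward()
print(mask.shape)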
def create_network(batch_size, num_dilations, learning_rate):
    # model
    x = nn.Variable(shape=(batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    s_emb = None

    net = WaveNet(num_dilations)
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))  # (B, T, C)

    t = nn.Variable(shape=(batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # loss.visit(PrintFunc())

    # Create Solver.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    return x, t, loss, solver
def build_train_graph(self, batch):
    self.solver = S.Adam(self.learning_rate)
    obs, action, reward, terminal, newobs = batch
    # Create input variables
    s = nn.Variable(obs.shape)
    a = nn.Variable(action.shape)
    r = nn.Variable(reward.shape)
    t = nn.Variable(terminal.shape)
    snext = nn.Variable(newobs.shape)
    with nn.parameter_scope(self.name_q):
        q = self.q_builder(s, self.num_actions, test=False)
        self.solver.set_parameters(nn.get_parameters())
    with nn.parameter_scope(self.name_qnext):
        qnext = self.q_builder(snext, self.num_actions, test=True)
    qnext.need_grad = False
    clipped_r = F.minimum_scalar(F.maximum_scalar(
        r, -self.clip_reward), self.clip_reward)
    q_a = F.sum(
        q * F.one_hot(F.reshape(a, (-1, 1), inplace=False), (q.shape[1],)),
        axis=1)
    target = clipped_r + self.gamma * (1 - t) * F.max(qnext, axis=1)
    loss = F.mean(F.huber_loss(q_a, target))
    Variables = namedtuple(
        'Variables', ['s', 'a', 'r', 't', 'snext', 'q', 'loss'])
    self.v = Variables(s, a, r, t, snext, q, loss)
    self.sync_models()
    self.built = True
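# Added sketch (not part of the original example): the one-hot action-masking trick used above,
# in isolation. Q-values of the taken actions are selected by multiplying the Q matrix with a
# one-hot encoding of the action indices and summing over the action axis. Sizes are illustrative.
import numpy as np
import nnabla as nn
import nnabla.functions as F

batch_size, num_actions = 4, 6
q = nn.Variable((batch_size, num_actions))
a = nn.Variable((batch_size, 1))
q.d = np.random.randn(batch_size, num_actions)
a.d = np.random.randint(0, num_actions, size=(batch_size, 1))

q_a = F.sum(q * F.one_hot(a, (num_actions, )), axis=1)   # (batch_size,)
q_a.forward()
assert np.allclose(q_a.d, q.d[np.arange(batch_size), a.d.astype(int).ravel()])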
def _build(self): # infer variable self.infer_obs_t = nn.Variable((1, 4, 84, 84)) # inference output self.infer_qs_t = self.q_function(self.infer_obs_t, self.num_actions, self.num_heads, 'q_func') self.infer_all = F.sink(*self.infer_qs_t) # train variables self.obss_t = nn.Variable((self.batch_size, 4, 84, 84)) self.acts_t = nn.Variable((self.batch_size, 1)) self.rews_tp1 = nn.Variable((self.batch_size, 1)) self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84)) self.ters_tp1 = nn.Variable((self.batch_size, 1)) self.weights = nn.Variable((self.batch_size, self.num_heads)) # training output qs_t = self.q_function(self.obss_t, self.num_actions, self.num_heads, 'q_func') qs_tp1 = q_function(self.obss_tp1, self.num_actions, self.num_heads, 'target') stacked_qs_t = F.transpose(F.stack(*qs_t), [1, 0, 2]) stacked_qs_tp1 = F.transpose(F.stack(*qs_tp1), [1, 0, 2]) # select one dimension a_one_hot = F.reshape(F.one_hot(self.acts_t, (self.num_actions, )), (-1, 1, self.num_actions)) # mask output q_t_selected = F.sum(stacked_qs_t * a_one_hot, axis=2) q_tp1_best = F.max(stacked_qs_tp1, axis=2) q_tp1_best.need_grad = False # reward clipping clipped_rews_tp1 = clip_by_value(self.rews_tp1, -1.0, 1.0) # loss calculation y = clipped_rews_tp1 + self.gamma * q_tp1_best * (1.0 - self.ters_tp1) td = F.huber_loss(q_t_selected, y) self.loss = F.mean(F.sum(td * self.weights, axis=1)) # optimizer self.solver = S.RMSprop(self.lr, 0.95, 1e-2) # weights and biases with nn.parameter_scope('q_func'): self.params = nn.get_parameters() self.head_params = [] for i in range(self.num_heads): with nn.parameter_scope('head%d' % i): self.head_params.append(nn.get_parameters()) with nn.parameter_scope('shared'): self.shared_params = nn.get_parameters() with nn.parameter_scope('target'): self.target_params = nn.get_parameters() # set q function parameters to solver self.solver.set_parameters(self.params)
def test_one_hot_forward(seed, inshape, shape, ctx, func_name):
    rng = np.random.RandomState(seed)
    # Input
    input = rng.randint(0, shape[0], size=inshape)
    vinput = nn.Variable(input.shape, need_grad=False)
    vinput.d = input
    with nn.context_scope(ctx), nn.auto_forward():
        o = F.one_hot(vinput, shape)
    r = ref_one_hot(input, shape)
    assert np.allclose(o.d, r)
    assert func_name == o.parent.name
def test_one_hot_forward(seed, inshape, shape, ctx, func_name):
    # Input
    input = np.zeros(inshape, dtype=int)
    rng = np.random.RandomState(seed)
    if len(shape) != inshape[-1]:
        # input inshape and shape don't match.
        with pytest.raises(RuntimeError):
            y = F.one_hot(nn.Variable(input.shape), shape)
    else:
        for i in range(inshape[-1]):
            input[:, i] = rng.randint(0, shape[i], size=inshape[0])
        vinput = nn.Variable(input.shape, need_grad=False)
        vinput.d = input
        with nn.context_scope(ctx), nn.auto_forward():
            o = F.one_hot(vinput, shape)
        r = ref_one_hot(input, shape)
        assert np.allclose(o.d, r)
        assert func_name == o.parent.name
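# Added sketch (assumption-labeled): one plausible NumPy reference for the two tests above.
# The real ref_one_hot helper is not shown in this listing; this version assumes a 2-D index
# input of shape (N, len(shape)) (a 1-D input is treated as (N, 1)) and scatters 1.0 at each
# index along the appended class dimensions.
import numpy as np

def ref_one_hot(input, shape):
    idx = np.asarray(input, dtype=int).reshape(len(input), -1)
    out = np.zeros((idx.shape[0], ) + tuple(shape), dtype=np.float32)
    for n, multi_index in enumerate(idx):
        out[(n, ) + tuple(multi_index)] = 1.0
    return out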
def forward_pass(self, img_var, labels):
    enc_indices, quantized = self.base_model(img_var, return_encoding_indices=True, test=True)
    labels_var = nn.Variable(labels.shape)
    if isinstance(labels, nn.NdArray):
        labels_var.data = labels
    else:
        labels_var.d = labels
    labels_var = F.one_hot(labels_var, shape=(self.num_classes, ))
    enc_recon = self.prior(quantized, labels_var)
    loss = F.mean(F.softmax_cross_entropy(enc_recon, enc_indices))
    return loss, enc_indices, enc_recon
def model_tweak_digitscaps(batch_size):
    '''
    Build the capsule-tweaking graph: the DigitCaps outputs, the target one-hot and a noise
    input are fed to the reconstruction decoder.
    '''
    image = nn.Variable((batch_size, 1, 28, 28))
    label = nn.Variable((batch_size, 1))
    x = image / 255.0
    t_onehot = F.one_hot(label, (10,))
    with nn.parameter_scope("capsnet"):
        _, _, _, caps, _ = model.capsule_net(
            x, test=True, aug=False, grad_dynamic_routing=True)
    noise = nn.Variable((batch_size, 1, caps.shape[2]))
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot, noise)
    return image, label, noise, recon
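# Added usage sketch (assumption-labeled): load trained parameters, feed a batch, nudge one
# dimension of the digit-capsule vector through the `noise` input, and read back the
# reconstruction. `load_mnist_batch` and the parameter file name are hypothetical placeholders.
import numpy as np
import nnabla as nn

batch_size = 16
image, label, noise, recon = model_tweak_digitscaps(batch_size)
nn.load_parameters('capsnet_params.h5')             # illustrative path to trained weights
image.d, label.d = load_mnist_batch(batch_size)     # hypothetical data helper
perturb = np.zeros(noise.shape, dtype=np.float32)
perturb[:, 0, 3] = 0.25                             # tweak a single capsule dimension
noise.d = perturb
recon.forward(clear_buffer=True)
tweaked_images = recon.d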
def mlp_gradient_synthesizer(x, y=None, test=False):
    maps = x.shape[1]
    if y is not None:
        h = F.one_hot(y, (10, ))
        # concatenate the features with the one-hot label (cf. cnn_gradient_synthesizer below)
        h = F.concatenate(*[x, h], axis=1)
    else:
        h = x

    with nn.parameter_scope("gs"):
        h = act_bn_linear(h, maps, test, name="fc0")
        h = act_bn_linear(h, maps, test, name="fc1")
        w_init = ConstantInitializer(0)
        b_init = ConstantInitializer(0)
        g_pred = PF.affine(h, maps, w_init=w_init, b_init=b_init, name="fc")
        g_pred.persistent = True
    return g_pred
def _build(self): # infer variable self.infer_obs_t = nn.Variable((1, 4, 84, 84)) # inference output self.infer_q_t = self.q_function(self.infer_obs_t, self.num_actions, scope='q_func') # train variables self.obss_t = nn.Variable((self.batch_size, 4, 84, 84)) self.acts_t = nn.Variable((self.batch_size, 1)) self.rews_tp1 = nn.Variable((self.batch_size, 1)) self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84)) self.ters_tp1 = nn.Variable((self.batch_size, 1)) self.weights = nn.Variable((self.batch_size, 1)) # training output q_t = self.q_function(self.obss_t, self.num_actions, scope='q_func') q_tp1 = self.q_function(self.obss_tp1, self.num_actions, scope='target_q_func') # select one dimension a_t_one_hot = F.one_hot(self.acts_t, (self.num_actions, )) q_t_selected = F.sum(q_t * a_t_one_hot, axis=1, keepdims=True) q_tp1_best = F.max(q_tp1, axis=1, keepdims=True) # loss calculation y = self.rews_tp1 + self.gamma * q_tp1_best * (1.0 - self.ters_tp1) self.td = q_t_selected - y self.loss = F.sum(F.huber_loss(q_t_selected, y) * self.weights) self.loss_sink = F.sink(self.td, self.loss) # optimizer self.solver = S.RMSprop(self.lr, 0.95, 1e-2) # weights and biases with nn.parameter_scope('q_func'): self.params = nn.get_parameters() with nn.parameter_scope('target_q_func'): self.target_params = nn.get_parameters() # set q function parameters to solver self.solver.set_parameters(self.params)
def __call__(self, x, return_encoding_indices=False): x = F.transpose(x, (0, 2, 3, 1)) x_flat = x.reshape((-1, self.embedding_dim)) x_flat_squared = F.broadcast(F.sum(x_flat**2, axis=1, keepdims=True), (x_flat.shape[0], self.num_embedding)) emb_wt_squared = F.transpose( F.sum(self.embedding_weight**2, axis=1, keepdims=True), (1, 0)) distances = x_flat_squared + emb_wt_squared - 2 * \ F.affine(x_flat, F.transpose(self.embedding_weight, (1, 0))) encoding_indices = F.min(distances, only_index=True, axis=1, keepdims=True) encoding_indices.need_grad = False quantized = F.embed( encoding_indices.reshape(encoding_indices.shape[:-1]), self.embedding_weight).reshape(x.shape) if return_encoding_indices: return encoding_indices, F.transpose(quantized, (0, 3, 1, 2)) encodings = F.one_hot(encoding_indices, (self.num_embedding, )) e_latent_loss = F.mean( F.squared_error(quantized.get_unlinked_variable(need_grad=False), x)) q_latent_loss = F.mean( F.squared_error(quantized, x.get_unlinked_variable(need_grad=False))) loss = q_latent_loss + self.commitment_cost * e_latent_loss quantized = x + (quantized - x).get_unlinked_variable(need_grad=False) avg_probs = F.mean(encodings, axis=0) perplexity = F.exp(-F.sum(avg_probs * F.log(avg_probs + 1.0e-10))) return loss, F.transpose(quantized, (0, 3, 1, 2)), perplexity, encodings
def __init__(self, num_actions, num_envs, batch_size, v_coeff, ent_coeff,
             lr_scheduler):
    # inference graph
    self.infer_obs_t = nn.Variable((num_envs, 4, 84, 84))
    self.infer_pi_t,\
        self.infer_value_t = cnn_network(self.infer_obs_t, num_actions,
                                         'network')
    self.infer_t = F.sink(self.infer_pi_t, self.infer_value_t)

    # evaluation graph
    self.eval_obs_t = nn.Variable((1, 4, 84, 84))
    self.eval_pi_t, _ = cnn_network(self.eval_obs_t, num_actions, 'network')

    # training graph
    self.obss_t = nn.Variable((batch_size, 4, 84, 84))
    self.acts_t = nn.Variable((batch_size, 1))
    self.rets_t = nn.Variable((batch_size, 1))
    self.advs_t = nn.Variable((batch_size, 1))

    pi_t, value_t = cnn_network(self.obss_t, num_actions, 'network')

    # value loss
    l2loss = F.squared_error(value_t, self.rets_t)
    self.value_loss = v_coeff * F.mean(l2loss)

    # policy loss
    log_pi_t = F.log(pi_t + 1e-20)
    a_one_hot = F.one_hot(self.acts_t, (num_actions, ))
    log_probs_t = F.sum(log_pi_t * a_one_hot, axis=1, keepdims=True)
    self.pi_loss = F.mean(log_probs_t * self.advs_t)

    # entropy regularization (bonus that encourages exploration)
    entropy = -ent_coeff * F.mean(F.sum(pi_t * log_pi_t, axis=1))

    self.loss = self.value_loss - self.pi_loss - entropy

    self.params = nn.get_parameters()
    self.solver = S.RMSprop(lr_scheduler(0.0), 0.99, 1e-5)
    self.solver.set_parameters(self.params)
    self.lr_scheduler = lr_scheduler
def cnn_gradient_synthesizer(x, y=None, test=False):
    bs = x.shape[0]
    maps = x.shape[1]
    s0, s1 = x.shape[2:]
    if y is not None:
        h = F.one_hot(y, (10, ))
        h = F.reshape(h, (bs, 10, 1, 1))
        h = F.broadcast(h, (bs, 10, s0, s1))
        h = F.concatenate(*[x, h], axis=1)
    else:
        h = x

    with nn.parameter_scope("gs"):
        h = act_bn_conv(h, maps, test, name="conv0")
        w_init = ConstantInitializer(0)
        b_init = ConstantInitializer(0)
        g_pred = PF.convolution(h, maps, kernel=(3, 3), pad=(1, 1),
                                w_init=w_init, b_init=b_init, name="conv")
        g_pred.persistent = True
    return g_pred
def define_network(self):
    if self.use_inst:
        obj_onehot, bm = encode_inputs(self.ist_mask, self.obj_mask,
                                       n_ids=self.conf.n_class)
        mask = F.concatenate(obj_onehot, bm, axis=1)
    else:
        om = self.obj_mask
        if len(om.shape) == 3:
            om = F.reshape(om, om.shape + (1, ))
        obj_onehot = F.one_hot(om, shape=(self.conf.n_class, ))
        mask = F.transpose(obj_onehot, (0, 3, 1, 2))

    generator = SpadeGenerator(self.conf.g_ndf, image_shape=self.conf.image_shape)
    z = F.randn(shape=(self.conf.batch_size, self.conf.z_dim))
    fake = generator(z, mask)
    # Pixel intensities of fake are [-1, 1]. Rescale them to [0, 1].
    fake = (fake + 1) / 2

    return fake
def network(x, d1, c1, d2, c2, test=False):
    # Input:x -> 1
    # OneHot -> 687
    h = F.one_hot(x, (687, ))
    # LSTM1 -> 200
    with nn.parameter_scope('LSTM1'):
        h = network_LSTM(h, d1, c1, 687, 100, test)
    # Slice -> 100
    h1 = F.slice(h, (0, ), (100, ), (1, ))
    # h2:CellOut -> 100
    h2 = F.slice(h, (100, ), (200, ), (1, ))
    # LSTM2 -> 128
    with nn.parameter_scope('LSTM2'):
        h3 = network_LSTM(h1, d2, c2, 100, 64, test)
    # h4:DelayOut
    h4 = F.identity(h1)
    # Slice_2 -> 64
    h5 = F.slice(h3, (0, ), (64, ), (1, ))
    # h6:CellOut_2 -> 64
    h6 = F.slice(h3, (64, ), (128, ), (1, ))
    # Affine_2 -> 687
    h7 = PF.affine(h5, (687, ), name='Affine_2')
    # h8:DelayOut_2
    h8 = F.identity(h5)
    # h7:Softmax
    h7 = F.softmax(h7)
    return h2, h4, h6, h8, h7
def train(): args = get_args() # Set context. from nnabla.ext_utils import get_extension_context logger.info("Running in {}:{}".format(args.context, args.type_config)) ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir) _data_source = data_iterator._data_source # dirty hack... # model x = nn.Variable( shape=(args.batch_size, data_config.duration, 1)) # (B, T, 1) onehot = F.one_hot(x, shape=(data_config.q_bit_len, )) # (B, T, C) wavenet_input = F.transpose(onehot, (0, 2, 1)) # (B, C, T) # speaker embedding if args.use_speaker_id: s_id = nn.Variable(shape=(args.batch_size, 1)) with nn.parameter_scope("speaker_embedding"): s_emb = PF.embed(s_id, n_inputs=_data_source.n_speaker, n_features=WavenetConfig.speaker_dims) s_emb = F.transpose(s_emb, (0, 2, 1)) else: s_emb = None net = WaveNet() wavenet_output = net(wavenet_input, s_emb) pred = F.transpose(wavenet_output, (0, 2, 1)) # (B, T, 1) t = nn.Variable(shape=(args.batch_size, data_config.duration, 1)) loss = F.mean(F.softmax_cross_entropy(pred, t)) # for generation prob = F.softmax(pred) # Create Solver. solver = S.Adam(args.learning_rate) solver.set_parameters(nn.get_parameters()) # Create monitor. monitor = Monitor(args.monitor_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=10) # setup save env. audio_save_path = os.path.join(os.path.abspath( args.model_save_path), "audio_results") if audio_save_path and not os.path.exists(audio_save_path): os.makedirs(audio_save_path) # Training loop. for i in range(args.max_iter): # todo: validation x.d, _speaker, t.d = data_iterator.next() if args.use_speaker_id: s_id.d = _speaker.reshape(-1, 1) solver.zero_grad() loss.forward(clear_no_need_grad=True) loss.backward(clear_buffer=True) solver.update() loss.data.cast(np.float32, ctx) monitor_loss.add(i, loss.d.copy()) if i % args.model_save_interval == 0: prob.forward() audios = mu_law_decode( np.argmax(prob.d, axis=-1), quantize=data_config.q_bit_len) # (B, T) save_audio(audios, i, audio_save_path)
def cond_att_lstm(x, parent_index, mask, context, context_mask, state_size, att_hidden_size, initial_state=None, initial_cell=None, hist=None, dropout=0, train=True, w_init=None, inner_w_init=None, b_init=I.ConstantInitializer(0), forget_bias_init=I.ConstantInitializer(1)): """ x: (batch_size, length, input_size) parent_index: (batch_size, length) mask: (batch_size, length) context: (batch_size, context_length, context_size) context_mask: (batch_size, context_length) hist: (batch_size, l, state_size) """ batch_size, length, input_size = x.shape _, context_length, context_size = context.shape if w_init is None: w_init = I.UniformInitializer( I.calc_uniform_lim_glorot(input_size, state_size)) if inner_w_init is None: inner_w_init = orthogonal retain_prob = 1.0 - dropout z_w = nn.Variable((batch_size, 4, input_size), need_grad=False) z_w.d = 1 z_u = nn.Variable((batch_size, 4, state_size), need_grad=False) z_u.d = 1 if dropout > 0: if train: z_w = F.dropout(z_w, p=retain_prob) z_u = F.dropout(z_u, p=retain_prob) z_w *= retain_prob z_u *= retain_prob z_w = F.reshape(z_w, (batch_size, 4, 1, input_size)) z_w = F.broadcast(z_w, (batch_size, 4, length, input_size)) z_w = F.split(z_w, axis=1) z_u = F.split(z_u, axis=1) xi = z_w[0] * x xf = z_w[1] * x xc = z_w[2] * x xo = z_w[3] * x with nn.parameter_scope("cond_att_lstm"): # (batch_size, length, state_size) with nn.parameter_scope("lstm"): xi = PF.affine( xi, state_size, base_axis=2, w_init=w_init, b_init=b_init, name="Wi") xf = PF.affine( xf, state_size, base_axis=2, w_init=w_init, b_init=forget_bias_init, name="Wf") xc = PF.affine( xc, state_size, base_axis=2, w_init=w_init, b_init=b_init, name="Wc") xo = PF.affine( xo, state_size, base_axis=2, w_init=w_init, b_init=b_init, name="Wo") with nn.parameter_scope("context"): # context_att_trans: (batch_size, context_size, att_hidden_size) context_att_trans = PF.affine( context, att_hidden_size, base_axis=2, w_init=w_init, b_init=b_init, name="layer1_c") if initial_state is None: h = nn.Variable((batch_size, state_size), need_grad=False) h.data.zero() else: h = initial_state if initial_cell is None: c = nn.Variable((batch_size, state_size), need_grad=False) c.data.zero() else: c = initial_cell if hist is None: hist = nn.Variable((batch_size, 1, state_size), need_grad=False) hist.data.zero() # (batch_size, state_size) xi = split(xi, axis=1) xf = split(xf, axis=1) xc = split(xc, axis=1) xo = split(xo, axis=1) mask = F.reshape(mask, [batch_size, length, 1]) # (batch_size, length, 1) mask = F.broadcast(mask, [batch_size, length, state_size]) # (batch_size, state_size) mask = split(mask, axis=1) # (batch_size, max_action_length) parent_index = parent_index + 1 # index == 0 means that parent is root # (batch_size) parent_index = split(parent_index, axis=1) hs = [] cs = [] ctx = [] for i, f, c2, o, m, p in zip(xi, xf, xc, xo, mask, parent_index): h_num = hist.shape[1] with nn.parameter_scope("context"): h_att_trans = PF.affine( h, att_hidden_size, with_bias=False, w_init=w_init, name="layer1_h") # (batch_size, att_hidden_size) h_att_trans = F.reshape(h_att_trans, (batch_size, 1, att_hidden_size)) h_att_trans = F.broadcast( h_att_trans, (batch_size, context_length, att_hidden_size)) att_hidden = F.tanh(context_att_trans + h_att_trans) att_raw = PF.affine( att_hidden, 1, base_axis=2, w_init=w_init, b_init=b_init) # (batch_size, context_length, 1) att_raw = F.reshape(att_raw, (batch_size, context_length)) ctx_att = F.exp(att_raw - F.max(att_raw, axis=1, keepdims=True)) ctx_att = ctx_att * context_mask ctx_att = 
ctx_att / F.sum(ctx_att, axis=1, keepdims=True) ctx_att = F.reshape(ctx_att, (batch_size, context_length, 1)) ctx_att = F.broadcast(ctx_att, (batch_size, context_length, context_size)) ctx_vec = F.sum( context * ctx_att, axis=1) # (batch_size, context_size) # parent_history p = F.reshape(p, (batch_size, 1)) p = F.one_hot(p, (h_num, )) p = F.reshape(p, (batch_size, 1, h_num)) par_h = F.batch_matmul(p, hist) # [batch_size, 1, state_size] par_h = F.reshape(par_h, (batch_size, state_size)) with nn.parameter_scope("lstm"): i_t = PF.affine( z_u[0] * h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Ui") i_t += PF.affine( ctx_vec, state_size, w_init=inner_w_init(context_size, state_size), with_bias=False, name="Ci") i_t += PF.affine( par_h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Pi") i_t = F.sigmoid(i + i_t) f_t = PF.affine( z_u[1] * h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Uf") f_t += PF.affine( ctx_vec, state_size, w_init=inner_w_init(context_size, state_size), with_bias=False, name="Cf") f_t += PF.affine( par_h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Pf") f_t = F.sigmoid(f + f_t) c_t = PF.affine( z_u[2] * h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Uc") c_t += PF.affine( ctx_vec, state_size, w_init=inner_w_init(context_size, state_size), with_bias=False, name="Cc") c_t += PF.affine( par_h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Pc") c_t = f_t * c + i_t * F.tanh(c2 + c_t) o_t = PF.affine( z_u[3] * h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Uo") o_t += PF.affine( ctx_vec, state_size, w_init=inner_w_init(context_size, state_size), with_bias=False, name="Co") o_t += PF.affine( par_h, state_size, w_init=inner_w_init(state_size, state_size), with_bias=False, name="Po") o_t = F.sigmoid(o + o_t) h_t = o_t * F.tanh(c_t) h_t = (1 - m) * h + m * h_t c_t = (1 - m) * c + m * c_t h = h_t c = c_t h_t = F.reshape(h_t, (batch_size, 1, state_size), inplace=False) c_t = F.reshape(c_t, (batch_size, 1, state_size), inplace=False) ctx_vec = F.reshape( ctx_vec, (batch_size, 1, context_size), inplace=False) hs.append(h_t) cs.append(c_t) ctx.append(ctx_vec) hist = F.concatenate( hist, h_t, axis=1) # (batch_size, h_num + 1, state_size) return concatenate( *hs, axis=1), concatenate( *cs, axis=1), concatenate( *ctx, axis=1), hist
def train(): ''' Main script. ''' args = get_args() from numpy.random import seed seed(0) # Get context. from nnabla.ext_utils import get_extension_context logger.info("Running in %s" % args.context) ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # TRAIN image = nn.Variable([args.batch_size, 1, 28, 28]) label = nn.Variable([args.batch_size, 1]) x = image / 255.0 t_onehot = F.one_hot(label, (10, )) with nn.parameter_scope("capsnet"): c1, pcaps, u_hat, caps, pred = model.capsule_net( x, test=False, aug=True, grad_dynamic_routing=args.grad_dynamic_routing) with nn.parameter_scope("capsnet_reconst"): recon = model.capsule_reconstruction(caps, t_onehot) loss_margin, loss_reconst, loss = model.capsule_loss( pred, t_onehot, recon, x) pred.persistent = True # TEST # Create input variables. vimage = nn.Variable([args.batch_size, 1, 28, 28]) vlabel = nn.Variable([args.batch_size, 1]) vx = vimage / 255.0 with nn.parameter_scope("capsnet"): _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False) # Create Solver. solver = S.Adam(args.learning_rate) solver.set_parameters(nn.get_parameters()) # Create monitor. from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed train_iter = int(60000 / args.batch_size) val_iter = int(10000 / args.batch_size) logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter)) monitor = Monitor(args.monitor_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=1) monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1) monitor_rloss = MonitorSeries("Training reconstruction loss", monitor, interval=1) monitor_err = MonitorSeries("Training error", monitor, interval=1) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1) monitor_verr = MonitorSeries("Test error", monitor, interval=1) monitor_lr = MonitorSeries("Learning rate", monitor, interval=1) # To_save_nnp m_image, m_label, m_noise, m_recon = model_tweak_digitscaps( args.batch_size) contents = save_nnp({ 'x1': m_image, 'x2': m_label, 'x3': m_noise }, {'y': m_recon}, args.batch_size) save.save(os.path.join(args.monitor_path, 'capsnet_epoch0_result.nnp'), contents) # Initialize DataIterator for MNIST. from numpy.random import RandomState data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223)) vdata = data_iterator_mnist(args.batch_size, False) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Training loop. 
for e in range(start_point, args.max_epochs): # Learning rate decay learning_rate = solver.learning_rate() if e != 0: learning_rate *= 0.9 solver.set_learning_rate(learning_rate) monitor_lr.add(e, learning_rate) # Training train_error = 0.0 train_loss = 0.0 train_mloss = 0.0 train_rloss = 0.0 for i in range(train_iter): image.d, label.d = data.next() solver.zero_grad() loss.forward(clear_no_need_grad=True) loss.backward(clear_buffer=True) solver.update() train_error += categorical_error(pred.d, label.d) train_loss += loss.d train_mloss += loss_margin.d train_rloss += loss_reconst.d train_error /= train_iter train_loss /= train_iter train_mloss /= train_iter train_rloss /= train_iter # Validation val_error = 0.0 for j in range(val_iter): vimage.d, vlabel.d = vdata.next() vpred.forward(clear_buffer=True) val_error += categorical_error(vpred.d, vlabel.d) val_error /= val_iter # Monitor monitor_time.add(e) monitor_loss.add(e, train_loss) monitor_mloss.add(e, train_mloss) monitor_rloss.add(e, train_rloss) monitor_err.add(e, train_error) monitor_verr.add(e, val_error) save_checkpoint(args.monitor_path, e, solver) # To_save_nnp contents = save_nnp({ 'x1': m_image, 'x2': m_label, 'x3': m_noise }, {'y': m_recon}, args.batch_size) save.save(os.path.join(args.monitor_path, 'capsnet_result.nnp'), contents)
def train(): """ Main script. Steps: * Parse command line arguments. * Specify contexts for computation. * Initialize DataIterator. * Construct a computation graph for training and one for validation. * Initialize solver and set parameter variables to that. * Create monitor instances for saving and displaying training stats. * Training loop * Computate error rate for validation data (periodically) * Get a next minibatch. * Execute forwardprop * Set parameter gradients zero * Execute backprop. * Solver updates parameters by using gradients computed by backprop. * Compute training error """ # Parse args args = get_args() n_valid_samples = 10000 bs_valid = args.batch_size extension_module = args.context ctx = get_extension_context(extension_module, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Dataset data_iterator = data_iterator_cifar10 n_class = 10 # Model architecture if args.net == "resnet18": prediction = functools.partial(resnet18_prediction, ncls=n_class, nmaps=64, act=F.relu) if args.net == "resnet34": prediction = functools.partial(resnet34_prediction, ncls=n_class, nmaps=64, act=F.relu) # Create training graphs test = False if args.mixtype == "mixup": mdl = MixupLearning(args.batch_size, alpha=args.alpha) elif args.mixtype == "cutmix": mdl = CutmixLearning((args.batch_size, 3, 32, 32), alpha=args.alpha, cutmix_prob=1.0) elif args.mixtype == "vhmixup": mdl = VHMixupLearning((args.batch_size, 3, 32, 32), alpha=args.alpha) else: print("[ERROR] Unknown mixtype: " + args.mixtype) return image_train = nn.Variable((args.batch_size, 3, 32, 32)) label_train = nn.Variable((args.batch_size, 1)) mix_image, mix_label = mdl.mix_data(single_image_augment(image_train), F.one_hot(label_train, (n_class, ))) pred_train = prediction(mix_image, test) loss_train = mdl.loss(pred_train, mix_label) input_train = {"image": image_train, "label": label_train} # Create validation graph test = True image_valid = nn.Variable((bs_valid, 3, 32, 32)) pred_valid = prediction(image_valid, test) input_valid = {"image": image_valid} # Solvers if args.solver == "Adam": solver = S.Adam() elif args.solver == "Momentum": solver = S.Momentum(lr=args.learning_rate) solver.set_parameters(nn.get_parameters()) # Create monitor from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed monitor = Monitor(args.save_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=10) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10) monitor_verr = MonitorSeries("Test error", monitor, interval=1) # Data Iterator tdata = data_iterator(args.batch_size, True) vdata = data_iterator(args.batch_size, False) print("Size of the training data: %d " % tdata.size) # Training-loop for i in range(args.max_iter): # Forward/Zerograd/Backward image, label = tdata.next() input_train["image"].d = image input_train["label"].d = label mdl.set_mix_ratio() loss_train.forward() solver.zero_grad() loss_train.backward() # Model update by solver if args.solver == "Momentum": if i == args.max_iter / 2: solver.set_learning_rate(args.learning_rate / 10.0) if i == args.max_iter / 4 * 3: solver.set_learning_rate(args.learning_rate / 10.0**2) solver.update() # Validation if (i + 1) % args.val_interval == 0 or i == 0: ve = 0. 
vdata._reset() vdata_pred = np.zeros((n_valid_samples, n_class)) vdata_label = np.zeros((n_valid_samples, 1), dtype=np.int32) for j in range(0, n_valid_samples, args.batch_size): image, label = vdata.next() input_valid["image"].d = image pred_valid.forward() vdata_pred[j:min(j + args.batch_size, n_valid_samples )] = pred_valid.d[:min( args.batch_size, n_valid_samples - j)] vdata_label[j:min(j + args.batch_size, n_valid_samples )] = label[:min(args. batch_size, n_valid_samples - j)] ve = categorical_error(vdata_pred, vdata_label) monitor_verr.add(i + 1, ve) if int((i + 1) % args.model_save_interval) == 0: nn.save_parameters( os.path.join(args.save_path, 'params_%06d.h5' % (i + 1))) # Monitering monitor_loss.add(i + 1, loss_train.d.copy()) monitor_time.add(i + 1) nn.save_parameters( os.path.join(args.save_path, 'params_%06d.h5' % (args.max_iter)))
def train(): rng = np.random.RandomState(803) conf = get_config() comm = init_nnabla(conf) # create data iterator if conf.dataset == "cityscapes": data_list = get_cityscape_datalist(conf.cityscapes, save_file=comm.rank == 0) n_class = conf.cityscapes.n_label_ids use_inst = True data_iter = create_cityscapes_iterator(conf.batch_size, data_list, comm=comm, image_shape=conf.image_shape, rng=rng, flip=conf.use_flip) elif conf.dataset == "ade20k": data_list = get_ade20k_datalist(conf.ade20k, save_file=comm.rank == 0) n_class = conf.ade20k.n_label_ids + 1 # class id + unknown use_inst = False load_shape = tuple( x + 30 for x in conf.image_shape) if conf.use_crop else conf.image_shape data_iter = create_ade20k_iterator(conf.batch_size, data_list, comm=comm, load_shape=load_shape, crop_shape=conf.image_shape, rng=rng, flip=conf.use_flip) else: raise NotImplementedError( "Currently dataset {} is not supported.".format(conf.dataset)) real = nn.Variable(shape=(conf.batch_size, 3) + conf.image_shape) obj_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape) if use_inst: ist_mask = nn.Variable(shape=(conf.batch_size, ) + conf.image_shape) obj_onehot, bm = encode_inputs(ist_mask, obj_mask, n_ids=n_class) mask = F.concatenate(obj_onehot, bm, axis=1) else: om = obj_mask if len(om.shape) == 3: om = F.reshape(om, om.shape + (1, )) obj_onehot = F.one_hot(om, shape=(n_class, )) mask = F.transpose(obj_onehot, (0, 3, 1, 2)) # generator generator = SpadeGenerator(conf.g_ndf, image_shape=conf.image_shape) z = F.randn(shape=(conf.batch_size, conf.z_dim)) fake = generator(z, mask) # unlinking ul_mask, ul_fake = get_unlinked_all(mask, fake) # discriminator discriminator = PatchGAN(n_scales=conf.d_n_scales) d_input_real = F.concatenate(real, ul_mask, axis=1) d_input_fake = F.concatenate(ul_fake, ul_mask, axis=1) d_real_out, d_real_feats = discriminator(d_input_real) d_fake_out, d_fake_feats = discriminator(d_input_fake) g_gan, g_feat, d_real, d_fake = discriminator.get_loss( d_real_out, d_real_feats, d_fake_out, d_fake_feats, use_fm=conf.use_fm, fm_lambda=conf.lambda_fm, gan_loss_type=conf.gan_loss_type) def _rescale(x): return rescale_values(x, input_min=-1, input_max=1, output_min=0, output_max=255) g_vgg = vgg16_perceptual_loss(_rescale(ul_fake), _rescale(real)) * conf.lambda_vgg set_persistent_all(fake, mask, g_gan, g_feat, d_real, d_fake, g_vgg) # loss g_loss = g_gan + g_feat + g_vgg d_loss = (d_real + d_fake) / 2 # load params if conf.load_params is not None: print("load parameters from {}".format(conf.load_params)) nn.load_parameters(conf.load_params) # Setup Solvers g_solver = S.Adam(beta1=0.) g_solver.set_parameters(get_params_startswith("spade_generator")) d_solver = S.Adam(beta1=0.) 
d_solver.set_parameters(get_params_startswith("discriminator")) # lr scheduler g_lrs = LinearDecayScheduler(start_lr=conf.g_lr, end_lr=0., start_iter=100, end_iter=200) d_lrs = LinearDecayScheduler(start_lr=conf.d_lr, end_lr=0., start_iter=100, end_iter=200) ipe = get_iteration_per_epoch(data_iter._size, conf.batch_size, round="ceil") if not conf.show_interval: conf.show_interval = ipe if not conf.save_interval: conf.save_interval = ipe if not conf.niter: conf.niter = 200 * ipe # Setup Reporter losses = { "g_gan": g_gan, "g_feat": g_feat, "g_vgg": g_vgg, "d_real": d_real, "d_fake": d_fake } reporter = Reporter(comm, losses, conf.save_path, nimage_per_epoch=min(conf.batch_size, 5), show_interval=10) progress_iterator = trange(conf.niter, disable=comm.rank > 0) reporter.start(progress_iterator) colorizer = Colorize(n_class) # output all config and dump to file if comm.rank == 0: conf.dump_to_stdout() write_yaml(os.path.join(conf.save_path, "config.yaml"), conf) epoch = 0 for itr in progress_iterator: if itr % ipe == 0: g_lr = g_lrs(epoch) d_lr = d_lrs(epoch) g_solver.set_learning_rate(g_lr) d_solver.set_learning_rate(d_lr) if comm.rank == 0: print( "\n[epoch {}] update lr to ... g_lr: {}, d_lr: {}".format( epoch, g_lr, d_lr)) epoch += 1 if conf.dataset == "cityscapes": im, ist, obj = data_iter.next() ist_mask.d = ist elif conf.dataset == "ade20k": im, obj = data_iter.next() else: raise NotImplemented() real.d = im obj_mask.d = obj # text embedding and create fake fake.forward() # update discriminator d_solver.zero_grad() d_loss.forward() d_loss.backward(clear_buffer=True) comm.all_reduced_solver_update(d_solver) # update generator ul_fake.grad.zero() g_solver.zero_grad() g_loss.forward() g_loss.backward(clear_buffer=True) # backward generator fake.backward(grad=None, clear_buffer=True) comm.all_reduced_solver_update(g_solver) # report iteration progress reporter() # report epoch progress show_epoch = itr // conf.show_interval if (itr % conf.show_interval) == 0: show_images = { "RealImages": real.data.get_data("r").transpose((0, 2, 3, 1)), "ObjectMask": colorizer(obj).astype(np.uint8), "GeneratedImage": fake.data.get_data("r").transpose( (0, 2, 3, 1)) } reporter.step(show_epoch, show_images) if (itr % conf.save_interval) == 0 and comm.rank == 0: nn.save_parameters( os.path.join(conf.save_path, 'param_{:03d}.h5'.format(show_epoch))) if comm.rank == 0: nn.save_parameters(os.path.join(conf.save_path, 'param_final.h5'))
def train(): ''' Main script. ''' args = get_args() from numpy.random import seed seed(0) # Get context. from nnabla.contrib.context import extension_context extension_module = args.context if args.context is None: extension_module = 'cpu' logger.info("Running in %s" % extension_module) ctx = extension_context(extension_module, device_id=args.device_id) nn.set_default_context(ctx) # TRAIN image = nn.Variable([args.batch_size, 1, 28, 28]) label = nn.Variable([args.batch_size, 1]) x = image / 255.0 t_onehot = F.one_hot(label, (10, )) with nn.parameter_scope("capsnet"): c1, pcaps, u_hat, caps, pred = model.capsule_net( x, test=False, aug=True, grad_dynamic_routing=args.grad_dynamic_routing) with nn.parameter_scope("capsnet_reconst"): recon = model.capsule_reconstruction(caps, t_onehot) loss_margin, loss_reconst, loss = model.capsule_loss( pred, t_onehot, recon, x) pred.persistent = True # TEST # Create input variables. vimage = nn.Variable([args.batch_size, 1, 28, 28]) vlabel = nn.Variable([args.batch_size, 1]) vx = vimage / 255.0 with nn.parameter_scope("capsnet"): _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False) # Create Solver. solver = S.Adam(args.learning_rate) solver.set_parameters(nn.get_parameters()) # Create monitor. from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed train_iter = int(60000 / args.batch_size) val_iter = int(10000 / args.batch_size) logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter)) monitor = Monitor(args.monitor_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=1) monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1) monitor_rloss = MonitorSeries("Training reconstruction loss", monitor, interval=1) monitor_err = MonitorSeries("Training error", monitor, interval=1) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1) monitor_verr = MonitorSeries("Test error", monitor, interval=1) monitor_lr = MonitorSeries("Learning rate", monitor, interval=1) # Initialize DataIterator for MNIST. from numpy.random import RandomState data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223)) vdata = data_iterator_mnist(args.batch_size, False) # Training loop. for e in range(args.max_epochs): # Learning rate decay learning_rate = solver.learning_rate() if e != 0: learning_rate *= 0.9 solver.set_learning_rate(learning_rate) monitor_lr.add(e, learning_rate) # Training train_error = 0.0 train_loss = 0.0 train_mloss = 0.0 train_rloss = 0.0 for i in range(train_iter): image.d, label.d = data.next() solver.zero_grad() loss.forward(clear_no_need_grad=True) loss.backward(clear_buffer=True) solver.update() train_error += categorical_error(pred.d, label.d) train_loss += loss.d train_mloss += loss_margin.d train_rloss += loss_reconst.d train_error /= train_iter train_loss /= train_iter train_mloss /= train_iter train_rloss /= train_iter # Validation val_error = 0.0 for j in range(val_iter): vimage.d, vlabel.d = vdata.next() vpred.forward(clear_buffer=True) val_error += categorical_error(vpred.d, vlabel.d) val_error /= val_iter # Monitor monitor_time.add(e) monitor_loss.add(e, train_loss) monitor_mloss.add(e, train_mloss) monitor_rloss.add(e, train_rloss) monitor_err.add(e, train_error) monitor_verr.add(e, val_error) nn.save_parameters( os.path.join(args.monitor_path, 'params_%06d.h5' % e))
def main(args): # Settings device_id = args.device_id batch_size = args.batch_size batch_size_eval = args.batch_size_eval n_l_train_data = 4000 n_train_data = 50000 n_cls = 10 learning_rate = 1. * 1e-3 n_epoch = 300 act = F.relu iter_epoch = int(n_train_data / batch_size) n_iter = n_epoch * iter_epoch extension_module = args.context alpha = args.alpha # Supervised Model ## ERM batch_size, m, h, w = batch_size, 3, 32, 32 ctx = extension_context(extension_module, device_id=device_id) x_l_0 = nn.Variable((batch_size, m, h, w)) y_l_0 = nn.Variable((batch_size, 1)) pred = cnn_model_003(ctx, x_l_0) loss_ce = ce_loss(ctx, pred, y_l_0) loss_er = er_loss(ctx, pred) loss_supervised = loss_ce + loss_er ## VRM (mixup) x_l_1 = nn.Variable((batch_size, m, h, w)) y_l_1 = nn.Variable((batch_size, 1)) coef = nn.Variable() coef_b = F.broadcast(coef.reshape([1]*x_l_0.ndim, unlink=True), x_l_0.shape) x_l_m = coef_b * x_l_0 + (1 - coef_b) * x_l_1 coef_b = F.broadcast(coef.reshape([1]*pred.ndim, unlink=True), pred.shape) y_l_m = coef_b * F.one_hot(y_l_0, (n_cls, )) \ + (1-coef_b) * F.one_hot(y_l_1, (n_cls, )) x_l_m.need_grad, y_l_m.need_grad = False, False pred_m = cnn_model_003(ctx, x_l_m) loss_er_m = er_loss(ctx, pred_m) #todo: need? loss_ce_m = ce_loss_soft(ctx, pred, y_l_m) loss_supervised_m = loss_ce_m #+ loss_er_m # Semi-Supervised Model ## ERM x_u0 = nn.Variable((batch_size, m, h, w)) x_u1 = nn.Variable((batch_size, m, h, w)) pred_x_u0 = cnn_model_003(ctx, x_u0) pred_x_u1 = cnn_model_003(ctx, x_u1) pred_x_u0.persistent, pred_x_u1.persistent = True, True loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1) loss_er0 = er_loss(ctx, pred_x_u0) loss_er1 = er_loss(ctx, pred_x_u1) loss_unsupervised = loss_sr + loss_er0 + loss_er1 ## VRM (mixup) x_u2 = nn.Variable((batch_size, m, h, w)) # not to overwrite x_u1.d coef_u = nn.Variable() coef_u_b = F.broadcast(coef_u.reshape([1]*x_u0.ndim, unlink=True), x_u0.shape) x_u_m = coef_u_b * x_u0 + (1-coef_u_b) * x_u2 pred_x_u0_ = nn.Variable(pred_x_u0.shape) # unlink forward pass but reuse result pred_x_u1_ = nn.Variable(pred_x_u1.shape) pred_x_u0_.data = pred_x_u0.data pred_x_u1_.data = pred_x_u1.data coef_u_b = F.broadcast(coef_u.reshape([1]*pred_x_u0.ndim, unlink=True), pred_x_u0.shape) y_u_m = coef_u_b * pred_x_u0_ + (1-coef_u_b) * pred_x_u1_ x_u_m.need_grad, y_u_m.need_grad = False, False pred_x_u_m = cnn_model_003(ctx, x_u_m) loss_er_u_m = er_loss(ctx, pred_x_u_m) #todo: need? loss_ce_u_m = ce_loss_soft(ctx, pred_x_u_m, y_u_m) loss_unsupervised_m = loss_ce_u_m #+ loss_er_u_m # Evaluatation Model batch_size_eval, m, h, w = batch_size, 3, 32, 32 x_eval = nn.Variable((batch_size_eval, m, h, w)) pred_eval = cnn_model_003(ctx, x_eval, test=True) # Solver with nn.context_scope(ctx): solver = S.Adam(alpha=learning_rate) solver.set_parameters(nn.get_parameters()) # Dataset ## separate dataset home = os.environ.get("HOME") fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz") separator = Separator(n_l_train_data) separator.separate_then_save(fpath) l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz") u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz") test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz") # data reader data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path, batch_size=batch_size, n_cls=n_cls, da=True, shape=True) # Training loop print("# Training loop") epoch = 1 st = time.time() acc_prev = 0. ve_best = 1. 
save_path_prev = "" for i in range(n_iter): # Get data and set it to the varaibles x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch() x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch() x_l_0.d, _ , y_l_0.d= x_l0_data, x_l1_data, y_l_data x_u0.d, x_u1.d= x_u0_data, x_u1_data # Train ## forward (supervised and its mixup) loss_supervised.forward(clear_no_need_grad=True) coef_data = np.random.beta(alpha, alpha) coef.d = coef_data x_l_1.d = np.random.permutation(x_l0_data) y_l_1.d = np.random.permutation(y_l_data) loss_supervised_m.forward(clear_no_need_grad=True) ## forward (unsupervised and its mixup) loss_unsupervised.forward(clear_no_need_grad=True) coef_data = np.random.beta(alpha, alpha) coef_u.d = coef_data x_u2.d = np.random.permutation(x_u1_data) loss_unsupervised_m.forward(clear_no_need_grad=True) ## backward solver.zero_grad() loss_supervised.backward(clear_buffer=False) loss_supervised_m.backward(clear_buffer=False) loss_unsupervised.backward(clear_buffer=False) loss_unsupervised_m.backward(clear_buffer=True) solver.update() # Evaluate if int((i+1) % iter_epoch) == 0: # Get data and set it to the varaibles x_data, y_data = data_reader.get_test_batch() # Evaluation loop ve = 0. iter_val = 0 for k in range(0, len(x_data), batch_size_eval): x_eval.d = get_test_data(x_data, k, batch_size_eval) label = get_test_data(y_data, k, batch_size_eval) pred_eval.forward(clear_buffer=True) ve += categorical_error(pred_eval.d, label) iter_val += 1 ve /= iter_val msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format( epoch, time.time() - st, (1. - ve) * 100) print(msg) if ve < ve_best: if not os.path.exists(args.model_save_path): os.makedirs(args.model_save_path) if save_path_prev != "": os.remove(save_path_prev) save_path = os.path.join( args.model_save_path, 'params_%06d.h5' % epoch) nn.save_parameters(save_path) save_path_prev = save_path ve_best = ve st = time.time() epoch +=1
def train(): if Config.USE_NW: env = Environment('Pong-v0') else: env = gym.make('Pong-v0') extension_module = Config.CONTEXT logger.info("Running in {}".format(extension_module)) ctx = extension_context(extension_module, device_id=Config.DEVICE_ID) nn.set_default_context(ctx) monitor = Monitor(Config.MONITOR_PATH) monitor_loss = MonitorSeries("Training loss", monitor, interval=1) monitor_reward = MonitorSeries("Training reward", monitor, interval=1) monitor_q = MonitorSeries("Training q", monitor, interval=1) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1) # placeholder image = nn.Variable([ Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH, Config.FRAME_HEIGHT ]) image_target = nn.Variable([ Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH, Config.FRAME_HEIGHT ]) nn.clear_parameters() # create network with nn.parameter_scope("dqn"): q = dqn(image, test=False) q.prersistent = True # Not to clear at backward with nn.parameter_scope("target"): target_q = dqn(image_target, test=False) target_q.prersistent = True # Not to clear at backward # loss definition a = nn.Variable([Config.BATCH_SIZE, 1]) q_val = F.sum(F.one_hot(a, (6, )) * q, axis=1, keepdims=True) t = nn.Variable([Config.BATCH_SIZE, 1]) loss = F.mean(F.squared_error(t, q_val)) if Config.RESUME: logger.info('load model: {}'.format(Config.RESUME)) nn.load_parameters(Config.RESUME) # setup solver # update dqn parameter only solver = S.RMSprop(lr=Config.LEARNING_RATE, decay=Config.DECAY, eps=Config.EPSILON) with nn.parameter_scope("dqn"): solver.set_parameters(nn.get_parameters()) # training epsilon = Config.INIT_EPSILON experiences = [] step = 0 for i in range(Config.EPISODE_LENGTH): logger.info("EPISODE {}".format(i)) done = False observation = env.reset() for i in range(30): observation_next, reward, done, info = env.step(0) observation_next = preprocess_frame(observation_next) # join 4 frame state = [observation_next for _ in xrange(Config.STATE_LENGTH)] state = np.stack(state, axis=0) total_reward = 0 while not done: # select action if step % Config.ACTION_INTERVAL == 0: if random.random() > epsilon or len( experiences) >= Config.REPLAY_MEMORY_SIZE: # inference image.d = state q.forward() action = np.argmax(q.d) else: # random action if Config.USE_NW: action = env.sample() else: action = env.action_space.sample() # TODO refactor if epsilon > Config.MIN_EPSILON: epsilon -= Config.EPSILON_REDUCTION_PER_STEP # get next environment observation_next, reward, done, info = env.step(action) observation_next = preprocess_frame(observation_next) total_reward += reward # TODO clip reward # update replay memory (FIFO) state_next = np.append(state[1:, :, :], observation_next[np.newaxis, :, :], axis=0) experiences.append((state_next, reward, action, state, done)) if len(experiences) > Config.REPLAY_MEMORY_SIZE: experiences.pop(0) # update network if step % Config.NET_UPDATE_INTERVAL == 0 and len( experiences) > Config.INIT_REPLAY_SIZE: logger.info("update {}".format(step)) batch = random.sample(experiences, Config.BATCH_SIZE) batch_observation_next = np.array([b[0] for b in batch]) batch_reward = np.array([b[1] for b in batch]) batch_action = np.array([b[2] for b in batch]) batch_observation = np.array([b[3] for b in batch]) batch_done = np.array([b[4] for b in batch], dtype=np.float32) batch_reward = batch_reward[:, np.newaxis] batch_action = batch_action[:, np.newaxis] batch_done = batch_done[:, np.newaxis] image.d = batch_observation.astype(np.float32) image_target.d = 
batch_observation_next.astype(np.float32) a.d = batch_action q_val.forward() # XXX target_q.forward() t.d = batch_reward + (1 - batch_done) * Config.GAMMA * np.max( target_q.d, axis=1, keepdims=True) solver.zero_grad() loss.forward() loss.backward() monitor_loss.add(step, loss.d.copy()) monitor_reward.add(step, total_reward) monitor_q.add(step, np.mean(q.d.copy())) monitor_time.add(step) # TODO weight clip solver.update() logger.info("update done {}".format(step)) # update target network if step % Config.TARGET_NET_UPDATE_INTERVAL == 0: # copy parameter from dqn to target with nn.parameter_scope("dqn"): src = nn.get_parameters() with nn.parameter_scope("target"): dst = nn.get_parameters() for (s_key, s_val), (d_key, d_val) in zip(src.items(), dst.items()): # Variable#d method is reference d_val.d = s_val.d.copy() if step % Config.MODEL_SAVE_INTERVAL == 0: logger.info("save model") nn.save_parameters("model_{}.h5".format(step)) step += 1 observation = observation_next state = state_next
def _build(self): # infer variable self.infer_obs_t = infer_obs_t = nn.Variable((1, 4, 84, 84)) # inference output self.infer_q_t,\ self.infer_probs_t, _ = self.q_function(infer_obs_t, self.num_actions, self.min_v, self.max_v, self.num_bins, 'q_func') self.infer_t = F.sink(self.infer_q_t, self.infer_probs_t) # train variables self.obss_t = nn.Variable((self.batch_size, 4, 84, 84)) self.acts_t = nn.Variable((self.batch_size, 1)) self.rews_tp1 = nn.Variable((self.batch_size, 1)) self.obss_tp1 = nn.Variable((self.batch_size, 4, 84, 84)) self.ters_tp1 = nn.Variable((self.batch_size, 1)) # training output q_t, probs_t, dists = self.q_function(self.obss_t, self.num_actions, self.min_v, self.max_v, self.num_bins, 'q_func') q_tp1, probs_tp1, _ = self.q_function(self.obss_tp1, self.num_actions, self.min_v, self.max_v, self.num_bins, 'target_q_func') expand_last = lambda x: F.reshape(x, x.shape + (1, )) flat = lambda x: F.reshape(x, (-1, 1)) # extract selected dimension a_t_one_hot = expand_last(F.one_hot(self.acts_t, (self.num_actions, ))) probs_t_selected = F.max(probs_t * a_t_one_hot, axis=1) # extract max dimension _, indices = F.max(q_tp1, axis=1, keepdims=True, with_index=True) a_tp1_one_hot = expand_last(F.one_hot(indices, (self.num_actions, ))) probs_tp1_best = F.max(probs_tp1 * a_tp1_one_hot, axis=1) # clipping reward clipped_rews_tp1 = clip_by_value(self.rews_tp1, -1.0, 1.0) disc_q_tp1 = F.reshape(dists, (1, -1)) * (1.0 - self.ters_tp1) t_z = clip_by_value(clipped_rews_tp1 + self.gamma * disc_q_tp1, self.min_v, self.max_v) # update indices b = (t_z - self.min_v) / ((self.max_v - self.min_v) / (self.num_bins - 1)) l = F.floor(b) l_mask = F.reshape(F.one_hot(flat(l), (self.num_bins, )), (-1, self.num_bins, self.num_bins)) u = F.ceil(b) u_mask = F.reshape(F.one_hot(flat(u), (self.num_bins, )), (-1, self.num_bins, self.num_bins)) m_l = expand_last(probs_tp1_best * (1 - (b - l))) m_u = expand_last(probs_tp1_best * (b - l)) m = F.sum(m_l * l_mask + m_u * u_mask, axis=1) m.need_grad = False self.loss = -F.mean(F.sum(m * F.log(probs_t_selected + 1e-10), axis=1)) # optimizer self.solver = S.RMSprop(self.lr, 0.95, 1e-2) # weights and biases with nn.parameter_scope('q_func'): self.params = nn.get_parameters() with nn.parameter_scope('target_q_func'): self.target_params = nn.get_parameters() # set q function parameters to solver self.solver.set_parameters(self.params)