def setup_solvers(train_params):
    """Create one Adam solver per sub-network of the first-order-motion model.

    Each solver is bound to the parameters registered under its own
    parameter scope ("generator", "discriminator", "kp_detector") and uses
    the learning rate found under the matching 'lr_<name>' key.

    Returns:
        dict: scope name -> configured S.Adam solver.
    """
    solvers = {}
    for scope in ("generator", "discriminator", "kp_detector"):
        solver = S.Adam(train_params['lr_' + scope], beta1=0.5, beta2=0.999)
        with nn.parameter_scope(scope):
            solver.set_parameters(nn.get_parameters())
        solvers[scope] = solver
    return solvers
def _build(self):
    """Build all DDPG computation graphs: inference, critic/actor losses,
    solvers, and the soft/hard target-network update expressions.

    Online ("trainable") and target networks live in separate parameter
    scopes so their parameters can be collected independently.
    """
    # inference graph: single-observation actor forward pass
    self.infer_obs_t = nn.Variable((1,) + self.obs_shape)
    with nn.parameter_scope('trainable'):
        self.infer_policy_t = policy_network(self.infer_obs_t, self.action_size, 'actor')

    # training input variables (batch of transitions s, a, r, s', done)
    self.obss_t = nn.Variable((self.batch_size,) + self.obs_shape)
    self.acts_t = nn.Variable((self.batch_size, self.action_size))
    self.rews_tp1 = nn.Variable((self.batch_size, 1))
    self.obss_tp1 = nn.Variable((self.batch_size,) + self.obs_shape)
    self.ters_tp1 = nn.Variable((self.batch_size, 1))

    # critic training: TD target y = r + gamma * Q'(s', pi'(s')) * (1 - done)
    with nn.parameter_scope('trainable'):
        q_t = q_network(self.obss_t, self.acts_t, 'critic')
    with nn.parameter_scope('target'):
        policy_tp1 = policy_network(self.obss_tp1, self.action_size, 'actor')
        q_tp1 = q_network(self.obss_tp1, policy_tp1, 'critic')
    y = self.rews_tp1 + self.gamma * q_tp1 * (1.0 - self.ters_tp1)
    self.critic_loss = F.mean(F.squared_error(q_t, y))

    # actor training: maximize Q(s, pi(s)) -> minimize its negation
    with nn.parameter_scope('trainable'):
        policy_t = policy_network(self.obss_t, self.action_size, 'actor')
        q_t_with_actor = q_network(self.obss_t, policy_t, 'critic')
    self.actor_loss = -F.mean(q_t_with_actor)

    # get neural network parameters (critic and actor separately)
    with nn.parameter_scope('trainable'):
        with nn.parameter_scope('critic'):
            critic_params = nn.get_parameters()
        with nn.parameter_scope('actor'):
            actor_params = nn.get_parameters()

    # setup optimizers, one per network
    self.critic_solver = S.Adam(self.critic_lr)
    self.critic_solver.set_parameters(critic_params)
    self.actor_solver = S.Adam(self.actor_lr)
    self.actor_solver.set_parameters(actor_params)

    with nn.parameter_scope('trainable'):
        trainable_params = nn.get_parameters()
    with nn.parameter_scope('target'):
        target_params = nn.get_parameters()

    # build target update:
    #   soft update: target <- (1 - tau) * target + tau * online
    #   sync:        target <- online (hard copy)
    update_targets = []
    sync_targets = []
    for key, src in trainable_params.items():
        dst = target_params[key]
        updated_dst = (1.0 - self.tau) * dst + self.tau * src
        update_targets.append(F.assign(dst, updated_dst))
        sync_targets.append(F.assign(dst, src))
    self.update_target_expr = F.sink(*update_targets)
    self.sync_target_expr = F.sink(*sync_targets)
def main():
    """Entry point: optionally train a pix2pix model on the facade dataset,
    then optionally generate images with the trained (or given) model."""
    # argparse
    args = get_args()

    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id)
    nn.set_default_context(ctx)

    # If training runs, it overwrites model_path with the newly trained model.
    model_path = args.model
    if args.train:
        # Data Loading
        logger.info("Initialing DataSource.")
        train_iterator = facade.facade_data_iterator(
            args.traindir,
            args.batchsize,
            shuffle=True,
            with_memory_cache=False)
        val_iterator = facade.facade_data_iterator(
            args.valdir,
            args.batchsize,
            random_crop=False,
            shuffle=False,
            with_memory_cache=False)

        monitor = nm.Monitor(args.logdir)
        solver_gen = S.Adam(alpha=args.lrate, beta1=args.beta1)
        solver_dis = S.Adam(alpha=args.lrate, beta1=args.beta1)

        generator = unet.generator
        discriminator = unet.discriminator

        model_path = train(generator, discriminator, args.patch_gan,
                           solver_gen, solver_dis, args.weight_l1,
                           train_iterator, val_iterator, args.epoch,
                           monitor, args.monitor_interval)

    if args.generate:
        if model_path is not None:
            # Data Loading
            logger.info("Generating from DataSource.")
            test_iterator = facade.facade_data_iterator(
                args.testdir,
                args.batchsize,
                shuffle=False,
                with_memory_cache=False)
            generator = unet.generator
            generate(generator, model_path, test_iterator, args.logdir)
        else:
            logger.error("Trained model was NOT given.")
def main():
    """Entry point: load the decoder config, set up the context, data
    iterator, monitors and solvers, then launch decoder GAN training."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default=None, type=str)
    parser.add_argument('--info', default=None, type=str)
    args = parser.parse_args()

    config = load_decoder_config(args.config)
    # Optional suffix to distinguish experiment runs.
    if args.info:
        config["experiment_name"] += args.info
    pprint.pprint(config)

    #########################
    # Context Setting
    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info(f'Running in {config["context"]}.')
    ctx = get_extension_context(config["context"],
                                device_id=config["device_id"])
    nn.set_default_context(ctx)
    #########################

    # Data Loading
    logger.info('Initialing Datasource')
    train_iterator = data.celebv_data_iterator(
        dataset_mode="decoder",
        celeb_name=config["trg_celeb_name"],
        data_dir=config["train_dir"],
        ref_dir=config["ref_dir"],
        mode=config["mode"],
        batch_size=config["train"]["batch_size"],
        shuffle=config["train"]["shuffle"],
        with_memory_cache=config["train"]["with_memory_cache"],
        with_file_cache=config["train"]["with_file_cache"],
    )

    # Logs go under <logdir>/decoder/<celeb>/<experiment>.
    monitor = nm.Monitor(
        os.path.join(config["logdir"], "decoder",
                     config["trg_celeb_name"], config["experiment_name"]))

    # Optimizer (one Adam per network, same hyper-parameters)
    solver_netG = S.Adam(alpha=config["train"]["lr"],
                         beta1=config["train"]["beta1"])
    solver_netD = S.Adam(alpha=config["train"]["lr"],
                         beta1=config["train"]["beta1"])

    # Network
    netG = models.netG_decoder
    netD = models.netD_decoder

    train(config, netG, netD, solver_netG, solver_netD,
          train_iterator, monitor)
def create_network(batch_size, num_dilations, learning_rate):
    """Build the WaveNet training graph and its Adam solver.

    Args:
        batch_size (int): samples per batch.
        num_dilations: dilation configuration passed to WaveNet.
        learning_rate (float): Adam alpha.

    Returns:
        tuple: (x input variable, t target variable, loss, solver).
    """
    # model
    x = nn.Variable(shape=(batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding (disabled: no speaker conditioning)
    s_emb = None

    net = WaveNet(num_dilations)
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))  # (B, T, C) class logits

    # (B, T, 1) integer class indices; softmax_cross_entropy consumes the
    # trailing singleton axis as the label axis.
    t = nn.Variable(shape=(batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # loss.visit(PrintFunc())

    # Create Solver.
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    return x, t, loss, solver
def main():
    """Train the CNN classifier, printing train metrics every 10 steps and
    full test-set metrics every 100 steps."""
    # Shape of the flattened input batch.
    x = nn.variable.Variable(
        [BATCH_SIZE, IMAGE_DEPTH * IMAGE_WIDTH * IMAGE_HEIGHT])
    # Shape of the label batch.
    t = nn.variable.Variable([BATCH_SIZE, LABEL_NUM])
    pred = convolution(x)
    loss_ = loss(pred, t)
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    data = InputData()
    for i in range(NUM_STEP):
        # Evaluate on the whole test set every 100 steps.
        if i % 100 == 0:
            total_loss = 0.0
            total_acc = 0.0
            num_batches = 0
            for t.d, x.d in data.test_data():
                loss_.forward()
                total_loss += loss_.d
                total_acc += accuracy(pred, t)
                num_batches += 1
            # BUG FIX: the original divided by the last enumerate() index
            # (num_batches - 1), over-stating both averages; divide by the
            # actual batch count. max() guards an empty test iterator.
            num_batches = max(num_batches, 1)
            print("Step: %05d Test loss: %0.05f Test accuracy: %0.05f" %
                  (i, total_loss / num_batches, total_acc / num_batches))
        # One SGD step on the next training batch.
        t.d, x.d = data.next_batch()
        loss_.forward()
        solver.zero_grad()
        loss_.backward()
        solver.weight_decay(DECAY_RATE)
        solver.update()
        if i % 10 == 0:
            print("Step: %05d Train loss: %0.05f Train accuracy: %0.05f" %
                  (i, loss_.d, accuracy(pred, t)))
def build_train_graph(self, batch):
    """Build the DQN training graph (Huber loss on clipped-reward TD target)
    from a sample batch, create the solver, and sync the target network.

    Args:
        batch: tuple of arrays (obs, action, reward, terminal, newobs) used
            only for their shapes.
    """
    self.solver = S.Adam(self.learning_rate)
    obs, action, reward, terminal, newobs = batch
    # Create input variables
    s = nn.Variable(obs.shape)
    a = nn.Variable(action.shape)
    r = nn.Variable(reward.shape)
    t = nn.Variable(terminal.shape)
    snext = nn.Variable(newobs.shape)
    # Online Q-network; only its parameters are optimized.
    with nn.parameter_scope(self.name_q):
        q = self.q_builder(s, self.num_actions, test=False)
        self.solver.set_parameters(nn.get_parameters())
    # Target Q-network: forward-only (no gradient).
    with nn.parameter_scope(self.name_qnext):
        qnext = self.q_builder(snext, self.num_actions, test=True)
    qnext.need_grad = False
    # Clamp rewards into [-clip_reward, clip_reward].
    clipped_r = F.minimum_scalar(F.maximum_scalar(
        r, -self.clip_reward), self.clip_reward)
    # Q(s, a): select the taken action's value via a one-hot mask.
    q_a = F.sum(
        q * F.one_hot(F.reshape(a, (-1, 1), inplace=False), (q.shape[1],)),
        axis=1)
    # TD target: r + gamma * max_a' Q_target(s', a'), zeroed at terminals.
    target = clipped_r + self.gamma * (1 - t) * F.max(qnext, axis=1)
    loss = F.mean(F.huber_loss(q_a, target))
    Variables = namedtuple(
        'Variables', ['s', 'a', 'r', 't', 'snext', 'q', 'loss'])
    self.v = Variables(s, a, r, t, snext, q, loss)
    # Hard-copy online weights into the target network before training.
    self.sync_models()
    self.built = True
def _setup_solvers(self):
    """Create Adam solvers plus step-decay LR schedulers for the
    discriminator and generator sub-networks of this scope."""
    # Discriminator: collect parameters, then scheduler + solver.
    with nn.parameter_scope('%s/discriminator' % self.scope):
        disc_params = nn.get_parameters()
    self.d_lr_scheduler = StepScheduler(self.d_lr, self.gamma,
                                        [self.lr_milestone])
    self.d_solver = S.Adam(self.d_lr, beta1=self.beta1, beta2=0.999)
    self.d_solver.set_parameters(disc_params)

    # Generator: same recipe with its own parameters and rates.
    with nn.parameter_scope('%s/generator' % self.scope):
        gen_params = nn.get_parameters()
    self.g_lr_scheduler = StepScheduler(self.g_lr, self.gamma,
                                        [self.lr_milestone])
    self.g_solver = S.Adam(self.g_lr, beta1=self.beta1, beta2=0.999)
    self.g_solver.set_parameters(gen_params)
def train_model(model, data, labels):
    """Run 1500 Adam steps minimizing categorical cross-entropy of
    `model(data)` against `labels`."""
    # Forward once so the model's parameters get registered.
    model(data)
    optimizer = S.Adam(alpha=0.001, beta1=0.9, beta2=0.999, eps=1e-08)
    optimizer.set_parameters(nn.get_parameters())
    step = 0
    while step < 1500:
        prediction = model(data)
        step_loss = F.categorical_cross_entropy(prediction, labels)
        step_loss.forward()
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.update()
        step += 1
def train(self):
    """Train the LSTM sequence predictor with MSE loss, periodically
    validating, then save the final parameters."""
    # variables for training
    tx_in = nn.Variable(
        [self._batch_size, self._x_input_length, self._cols_size])
    tx_out = nn.Variable(
        [self._batch_size, self._x_output_length, self._cols_size])
    tpred = self.network(tx_in, self._lstm_unit_name, self._lstm_units)
    tpred.persistent = True
    loss = F.mean(F.squared_error(tpred, tx_out))
    solver = S.Adam(self._learning_rate)
    solver.set_parameters(nn.get_parameters())

    # variables for validation (separate graph, shared parameters)
    vx_in = nn.Variable(
        [self._batch_size, self._x_input_length, self._cols_size])
    vx_out = nn.Variable(
        [self._batch_size, self._x_output_length, self._cols_size])
    vpred = self.network(vx_in, self._lstm_unit_name, self._lstm_units)

    # data iterators
    tdata = self._load_dataset(self._training_dataset_path,
                               self._batch_size, shuffle=True)
    vdata = self._load_dataset(self._validation_dataset_path,
                               self._batch_size, shuffle=True)

    # monitors
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(self._monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=10)

    # Training loop
    for i in range(self._max_iter):
        # Validate every _val_interval iterations (including i == 0).
        if i % self._val_interval == 0:
            ve = self._validate(vpred, vx_in, vx_out, vdata, self._val_iter)
            monitor_verr.add(i, ve / self._val_iter)
        te = self._train(tpred, solver, loss, tx_in, tx_out,
                         tdata.next(), self._weight_decay)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, te)
        monitor_time.add(i)

    # Final validation after the last iteration.
    ve = self._validate(vpred, vx_in, vx_out, vdata, self._val_iter)
    monitor_verr.add(i, ve / self._val_iter)

    # Save a best model parameters
    # NOTE(review): this saves the *final* parameters, not the best
    # validation checkpoint — confirm intent.
    nn.save_parameters(self._model_params_path)
def __init__(self, batch_size=32, learning_rate=1e-4, max_iter=5086,
             total_epochs=20, monitor_path=None, val_weight=None,
             model_load_path=None):
    """
    Construct all the necessary attributes for the attribute classifier.

    Args:
        batch_size (int): number of samples contained in each generated batch
        learning_rate (float) : learning rate
        max_iter (int) : maximum iterations for an epoch
        total_epochs (int) : total epochs to train the model
        val_weight : sample weights
        monitor_path (str) : model parameter to be saved
        model_load_path (str) : load the model
    """
    self.batch_size = batch_size
    # Resnet 50
    # training graph
    model = ResNet50()
    self.input_image = nn.Variable((self.batch_size, ) + model.input_shape)
    self.label = nn.Variable([self.batch_size, 1])
    # fine tuning: reuse the backbone up to global pooling, then a new head
    pool = model(self.input_image, training=True, use_up_to='pool')
    self.clf = clf_resnet50(pool)
    self.clf.persistent = True
    # loss: binary attribute classification on logits
    self.loss = F.mean(F.sigmoid_cross_entropy(self.clf, self.label))
    # hyper parameters
    self.solver = S.Adam(learning_rate)
    self.solver.set_parameters(nn.get_parameters())
    # validation graph (shared parameters, inference mode)
    self.x_v = nn.Variable((self.batch_size, ) + model.input_shape)
    pool_v = model(self.x_v, training=False, use_up_to='pool')
    self.v_clf = clf_resnet50(pool_v, train=False)
    self.v_clf_out = F.sigmoid(self.v_clf)
    self.print_freq = 100
    self.validation_weight = val_weight
    # val params
    self.acc = 0.0
    self.total_epochs = total_epochs
    self.max_iter = max_iter
    self.monitor_path = monitor_path
    # Optionally restore pre-trained weights.
    if model_load_path is not None:
        _ = nn.load_parameters(model_load_path)
def data_distill(model, uniform_data_iterator, num_iter):
    """ZeroQ-style data distillation: optimize random images so that the
    frozen model's activation statistics match its batch-norm statistics.

    Args:
        model: frozen pretrained network (forward hooks collect outputs).
        uniform_data_iterator: iterator yielding uniform-noise images.
        num_iter (int): optimization iterations per batch.

    Returns:
        list: distilled image batches (numpy arrays).
    """
    generated_img = []
    for _ in range(uniform_data_iterator.size // uniform_data_iterator.batch_size):
        img, _ = uniform_data_iterator.next()
        # The image itself is the trainable "parameter".
        dst_img = nn.Variable(img.shape, need_grad=True)
        dst_img.d = img
        img_params = OrderedDict()
        img_params['img'] = dst_img
        init_lr = 0.5
        solver = S.Adam(alpha=init_lr)
        solver.set_parameters(img_params)
        #scheduler = lr_scheduler.CosineScheduler(init_lr=0.5, max_iter=num_iter)
        scheduler = ReduceLROnPlateauScheduler(init_lr=init_lr, min_lr=1e-4,
                                               verbose=False, patience=100)
        # Dummy solver only used to zero the model's parameter gradients.
        dummy_solver = S.Sgd(lr=0)
        dummy_solver.set_parameters(nn.get_parameters())
        for it in tqdm(range(num_iter)):
            lr = scheduler.get_learning_rate()
            solver.set_learning_rate(lr)
            # `outs` / `batch_stats` are module-level lists filled by the
            # `get_output` forward hook below.
            global outs
            outs = []
            global batch_stats
            batch_stats = []
            y = model(denormalize(dst_img), force_global_pooling=True,
                      training=False)  # denormalize to U(0, 255)
            y.forward(function_post_hook=get_output)
            assert len(outs) == len(batch_stats)
            loss = zeroq_loss(batch_stats, outs, dst_img)
            loss.forward()
            solver.zero_grad()
            dummy_solver.zero_grad()
            loss.backward()
            solver.weight_decay(1e-6)
            solver.update()
            scheduler.update_lr(loss.d)
        generated_img.append(dst_img.d)
    return generated_img
def __init__(self, generator, args):
    """Initialize the latent projector state and immediately run the
    projection defined by `args`."""
    # Model pieces.
    self.generator = generator
    self.latent_dim = self.generator.mapping_network_dim
    self.solver = S.Adam()
    self.lpips_distance = LPIPS(model='vgg')
    # Optimization settings.
    self.base_lr = 0.1
    self.num_iters = 500
    # Image / latent-sampling configuration.
    self.img_size = 1024
    self.n_latent = 10000
    # Loss-term coefficients.
    self.mse_c = 0.0
    self.n_c = 1e5
    # Kick off the projection right away.
    self.project(args)
def train(max_iter=60000):
    """Train the MNIST VAE for `max_iter` iterations, logging train/val
    losses, and return the monitor output directory."""
    # Initialize data provider
    di_l = I.data_iterator_mnist(batch_size, True)
    di_t = I.data_iterator_mnist(batch_size, False)

    # Network: shared input variable feeds both train and test graphs.
    shape_x = (1, 28, 28)
    shape_z = (50, )
    x = nn.Variable((batch_size, ) + shape_x)
    loss_l = I.vae(x, shape_z, test=False)
    loss_t = I.vae(x, shape_z, test=True)

    # Create solver
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors for training and validation
    path = cache_dir(os.path.join(I.name, "monitor"))
    monitor = M.Monitor(path)
    monitor_train_loss = M.MonitorSeries("train_loss", monitor, interval=600)
    monitor_val_loss = M.MonitorSeries("val_loss", monitor, interval=600)
    monitor_time = M.MonitorTimeElapsed("time", monitor, interval=600)

    # Training Loop.
    for i in range(max_iter):
        # Initialize gradients
        solver.zero_grad()

        # Forward, backward and update
        x.d, _ = di_l.next()
        loss_l.forward(clear_no_need_grad=True)
        loss_l.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()

        # Forward for test (reuses x, so this runs every iteration)
        x.d, _ = di_t.next()
        loss_t.forward(clear_no_need_grad=True)

        # Monitor for logging
        monitor_train_loss.add(i, loss_l.d.copy())
        monitor_val_loss.add(i, loss_t.d.copy())
        monitor_time.add(i)

    return path
def main():
    """Train an MNIST classifier for one epoch on the cudnn context,
    logging training loss per batch and test error per epoch."""
    # Context
    ctx = get_extension_context("cudnn", device_id="0")
    nn.set_default_context(ctx)
    nn.auto_forward(False)

    # Inputs
    b, c, h, w = 64, 1, 28, 28
    x = nn.Variable([b, c, h, w])
    t = nn.Variable([b, 1])
    vx = nn.Variable([b, c, h, w])
    vt = nn.Variable([b, 1])

    # Model: shared parameters between train and test graphs.
    model = Model()
    pred = model(x)
    loss = F.softmax_cross_entropy(pred, t)
    vpred = model(vx, test=True)
    verror = F.top_n_error(vpred, vt)

    # Solver
    solver = S.Adam()
    solver.set_parameters(model.get_parameters(grad_only=True))

    # Data Iterator
    tdi = data_iterator_mnist(b, train=True)
    vdi = data_iterator_mnist(b, train=False)

    # Monitor
    monitor = Monitor("tmp.monitor")
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Training loop (single epoch)
    for e in range(1):
        for j in range(tdi.size // b):
            i = e * tdi.size // b + j
            x.d, t.d = tdi.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            monitor_loss.add(i, loss.d)

        # Evaluate mean top-n error over the test set once per epoch.
        error = 0.0
        for _ in range(vdi.size // b):
            vx.d, vt.d = vdi.next()
            verror.forward(clear_buffer=True)
            error += verror.d
        error /= vdi.size // b
        monitor_verr.add(i, error)
def create_model(net_name: str, batch_size=128, learning_rate=0.001):
    """Build train/val graphs and an Adam solver for the named MNIST net.

    Returns a ``model`` namedtuple ``(name, train, val, solver)`` where
    ``train``/``val`` are ``net`` namedtuples ``(image, label, pred, loss)``.
    """
    nn.clear_parameters()
    net = namedtuple("net", ("image", "label", "pred", "loss"))
    model = namedtuple("model", ("name", "train", "val", "solver"))
    predict = get_net_prediction(net_name)

    # Binary nets take raw [0, 255] pixels; everything else is unscaled.
    if net_name.startswith("binary"):
        norm = 255
    else:
        norm = 1

    def build_branch(test, persistent):
        # One forward graph; parameters are shared across branches.
        image = nn.Variable([batch_size, 1, 28, 28])
        label = nn.Variable([batch_size, 1])
        pred = predict(image / norm, test=test)
        if persistent:
            pred.persistent = True
        loss = F.mean(F.softmax_cross_entropy(pred, label))
        return net(image, label, pred, loss)

    train_branch = build_branch(test=False, persistent=True)
    val_branch = build_branch(test=True, persistent=False)

    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())
    return model(net_name, train_branch, val_branch, solver)
def train(max_iter=5000, learning_rate=0.001, weight_decay=0):
    """Train the network, validating every 100 iterations, save the final
    parameters, and return the monitor directory path."""
    train = create_net(False)
    test = create_net(True)

    # Create the solver
    solver = S.Adam(learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitors
    path = cache_dir(os.path.join(I.name, "monitor"))
    monitor = M.Monitor(path)
    monitor_loss_train = M.MonitorSeries("training_loss", monitor,
                                         interval=100)
    monitor_time = M.MonitorTimeElapsed("time", monitor, interval=100)
    monitor_loss_val = M.MonitorSeries("val_loss", monitor, interval=100)

    # Run training
    for i in range(max_iter):
        # Average validation loss over 10 batches every 100 iterations.
        if (i + 1) % 100 == 0:
            val_error = 0.0
            val_iter = 10
            for j in range(val_iter):
                test.image0.d, test.image1.d, test.label.d = test.data.next()
                test.loss.forward(clear_buffer=True)
                val_error += test.loss.d
            monitor_loss_val.add(i, val_error / val_iter)

        # One training step.
        train.image0.d, train.image1.d, train.label.d = train.data.next()
        solver.zero_grad()
        train.loss.forward(clear_no_need_grad=True)
        train.loss.backward(clear_buffer=True)
        solver.weight_decay(weight_decay)
        solver.update()
        monitor_loss_train.add(i, train.loss.d.copy())
        monitor_time.add(i)

    nn.save_parameters(os.path.join(path, "params.h5"))
    return path
def classification_svd():
    """Train the SVD-compressed ("slim") LeNet on MNIST, initializing it by
    decomposing a pretrained reference model, and save checkpoints."""
    args = get_args()

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    mnist_cnn_prediction = mnist_lenet_prediction_slim

    # TRAIN
    reference = "reference"
    slim = "slim"
    rrate = 0.5  # reduction rate

    # Create input variables.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    # Create `reference` and "slim" prediction graph.
    model_load_path = args.model_load_path
    pred = mnist_cnn_prediction(image, scope=slim, rrate=rrate, test=False)
    pred.persistent = True
    # Decompose and set parameters (SVD of the reference weights
    # initializes the slim network).
    decompose_network_and_set_params(model_load_path, reference, slim, rrate)
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create reference prediction graph.
    vpred = mnist_cnn_prediction(vimage, scope=slim, rrate=rrate, test=True)

    # Create Solver: only the slim scope's parameters are trained.
    solver = S.Adam(args.learning_rate)
    with nn.parameter_scope(slim):
        solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Initialize DataIterator for MNIST.
    data = data_iterator_mnist(args.batch_size, True)
    vdata = data_iterator_mnist(args.batch_size, False)
    best_ve = 1.0

    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation
            ve = 0.0
            for j in range(args.val_iter):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(i, ve / args.val_iter)
            # Checkpoint whenever summed validation error improves.
            if ve < best_ve:
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'params_%06d.h5' % i))
                best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation pass after training.
    ve = 0.0
    for j in range(args.val_iter):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(i, ve / args.val_iter)

    parameter_file = os.path.join(
        args.model_save_path, 'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def main(args):
    """Semi-supervised CIFAR-10 training: a CNN trained with an uncertainty-
    weighted stochastic-regularization loss, plus a ResNet student trained
    by knowledge transfer from the CNN; both evaluated each epoch."""
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000
    n_train_data = 50000
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    # NOTE(review): under Python 3 `/` yields a float here, and
    # `range(n_iter)` below would raise TypeError — this code appears to
    # assume Python 2 integer division; use `//` if running on Python 3.
    iter_epoch = n_train_data / batch_size
    n_iter = n_epoch * iter_epoch
    extension_module = args.context
    lambda_ = args.lambda_

    # Model
    ## supervised cnn
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred, log_var = cnn_model_003(ctx, "cnn", x_l)
    one = F.constant(1., log_var.shape)
    loss_ce = ce_loss(ctx, pred, y_l)
    reg_sigma = sigma_regularization(ctx, log_var, one)
    loss_supervised = loss_ce + er_loss(ctx, pred) + lambda_ * reg_sigma

    ## supervised resnet
    pred_res = cifar10_resnet23_prediction(ctx, "resnet", x_l)
    loss_res_ce = ce_loss(ctx, pred_res, y_l)
    loss_res_supervised = loss_res_ce

    ## stochastic regularization: two augmented views of unlabeled data
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u0.persistent = True
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0, log_var0 = cnn_model_003(ctx, "cnn", x_u0)
    pred_x_u0.persistent = True
    pred_x_u1, log_var1 = cnn_model_003(ctx, "cnn", x_u1)
    loss_sr = sr_loss_with_uncertainty(ctx, pred_x_u0, pred_x_u1,
                                       log_var0, log_var1)
    reg_sigma0 = sigma_regularization(ctx, log_var0, one)
    reg_sigma1 = sigma_regularization(ctx, log_var1, one)
    reg_sigmas = sigmas_regularization(ctx, log_var0, log_var1)
    loss_unsupervised = loss_sr + er_loss(ctx, pred_x_u0) + er_loss(ctx, pred_x_u1) \
        + lambda_ * (reg_sigma0 + reg_sigma1) + lambda_ * reg_sigmas

    ## knowledge transfer for resnet (CNN acts as teacher)
    pred_res_x_u0 = cifar10_resnet23_prediction(ctx, "resnet", x_u0)
    loss_res_unsupervised = kl_divergence(ctx, pred_res_x_u0, pred_x_u0,
                                          log_var0)

    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    x_eval.persistent = True
    pred_eval, _ = cnn_model_003(ctx, "cnn", x_eval, test=True)
    pred_res_eval = cifar10_resnet23_prediction(ctx, "resnet", x_eval,
                                                test=True)

    # Solver
    # NOTE(review): the nested duplicate `context_scope` block looks like a
    # copy-paste remnant (it is harmless but redundant) — confirm.
    with nn.context_scope(ctx):
        # Solver
        with nn.context_scope(ctx):
            with nn.parameter_scope("cnn"):
                solver = S.Adam(alpha=learning_rate)
                solver.set_parameters(nn.get_parameters())
            with nn.parameter_scope("resnet"):
                solver_res = S.Adam(alpha=learning_rate)
                solver_res.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset into labeled/unlabeled splits
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()
        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train cnn
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        loss_unsupervised.backward(clear_buffer=True)
        solver.update()

        # Train resnet
        loss_res_supervised.forward(clear_no_need_grad=True)
        loss_res_unsupervised.forward(clear_no_need_grad=True)
        solver_res.zero_grad()
        loss_res_supervised.backward(clear_buffer=True)
        pred_x_u0.need_grad = False  # no need grad for teacher
        loss_res_unsupervised.backward(clear_buffer=True)
        solver_res.update()
        pred_x_u0.need_grad = True

        # Evaluate once per epoch
        if (i + 1) % iter_epoch == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop for cnn
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch, time.time() - st, (1. - ve / iter_val) * 100)
            print(msg)

            # Evaluation loop for resnet
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_res_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_res_eval.d, label)
                iter_val += 1
            msg = "Model:resnet,Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch, time.time() - st, (1. - ve / iter_val) * 100)
            print(msg)
            st = time.time()
            epoch += 1
def train(args):
    """Train StarGAN on CelebA attributes: WGAN-GP discriminator plus a
    generator with classification and cycle-reconstruction losses, then
    translate the test split with the trained model."""
    if args.c_dim != len(args.selected_attrs):
        print("c_dim must be the same as the num of selected attributes. Modified c_dim.")
        args.c_dim = len(args.selected_attrs)

    # Dump the config information.
    config = dict()
    print("Used config:")
    for k in args.__dir__():
        if not k.startswith("_"):
            config[k] = getattr(args, k)
            print("'{}' : {}".format(k, getattr(args, k)))

    # Prepare Generator and Discriminator based on user config.
    generator = functools.partial(
        model.generator, conv_dim=args.g_conv_dim, c_dim=args.c_dim,
        num_downsample=args.num_downsample, num_upsample=args.num_upsample,
        repeat_num=args.g_repeat_num)
    discriminator = functools.partial(
        model.discriminator, image_size=args.image_size,
        conv_dim=args.d_conv_dim, c_dim=args.c_dim,
        repeat_num=args.d_repeat_num)

    x_real = nn.Variable(
        [args.batch_size, 3, args.image_size, args.image_size])
    label_org = nn.Variable([args.batch_size, args.c_dim, 1, 1])
    label_trg = nn.Variable([args.batch_size, args.c_dim, 1, 1])

    with nn.parameter_scope("dis"):
        dis_real_img, dis_real_cls = discriminator(x_real)

    with nn.parameter_scope("gen"):
        x_fake = generator(x_real, label_trg)
    x_fake.persistent = True  # to retain its value during computation.

    # get an unlinked_variable of x_fake so D's graph does not backprop
    # into G during the discriminator update.
    x_fake_unlinked = x_fake.get_unlinked_variable()

    with nn.parameter_scope("dis"):
        dis_fake_img, dis_fake_cls = discriminator(x_fake_unlinked)

    # ---------------- Define Loss for Discriminator -----------------
    d_loss_real = (-1) * loss.gan_loss(dis_real_img)
    d_loss_fake = loss.gan_loss(dis_fake_img)
    d_loss_cls = loss.classification_loss(dis_real_cls, label_org)
    d_loss_cls.persistent = True

    # Gradient Penalty: penalize |grad D(x_hat)| away from 1 on random
    # interpolations between real and fake images.
    alpha = F.rand(shape=(args.batch_size, 1, 1, 1))
    x_hat = F.mul2(alpha, x_real) + \
        F.mul2(F.r_sub_scalar(alpha, 1), x_fake_unlinked)

    with nn.parameter_scope("dis"):
        dis_for_gp, _ = discriminator(x_hat)
    grads = nn.grad([dis_for_gp], [x_hat])

    l2norm = F.sum(grads[0] ** 2.0, axis=(1, 2, 3)) ** 0.5
    d_loss_gp = F.mean((l2norm - 1.0) ** 2.0)

    # total discriminator loss.
    d_loss = d_loss_real + d_loss_fake + args.lambda_cls * \
        d_loss_cls + args.lambda_gp * d_loss_gp

    # ---------------- Define Loss for Generator -----------------
    g_loss_fake = (-1) * loss.gan_loss(dis_fake_img)
    g_loss_cls = loss.classification_loss(dis_fake_cls, label_trg)
    g_loss_cls.persistent = True

    # Reconstruct Images (cycle consistency back to the original domain).
    with nn.parameter_scope("gen"):
        x_recon = generator(x_fake_unlinked, label_org)
    x_recon.persistent = True

    g_loss_rec = loss.recon_loss(x_real, x_recon)
    g_loss_rec.persistent = True

    # total generator loss.
    g_loss = g_loss_fake + args.lambda_rec * \
        g_loss_rec + args.lambda_cls * g_loss_cls

    # -------------------- Solver Setup ---------------------
    d_lr = args.d_lr  # initial learning rate for Discriminator
    g_lr = args.g_lr  # initial learning rate for Generator
    solver_dis = S.Adam(alpha=args.d_lr, beta1=args.beta1, beta2=args.beta2)
    solver_gen = S.Adam(alpha=args.g_lr, beta1=args.beta1, beta2=args.beta2)

    # register parameters to each solver.
    with nn.parameter_scope("dis"):
        solver_dis.set_parameters(nn.get_parameters())
    with nn.parameter_scope("gen"):
        solver_gen.set_parameters(nn.get_parameters())

    # -------------------- Create Monitors --------------------
    monitor = Monitor(args.monitor_path)
    monitor_d_cls_loss = MonitorSeries(
        'real_classification_loss', monitor, args.log_step)
    monitor_g_cls_loss = MonitorSeries(
        'fake_classification_loss', monitor, args.log_step)
    monitor_loss_dis = MonitorSeries(
        'discriminator_loss', monitor, args.log_step)
    monitor_recon_loss = MonitorSeries(
        'reconstruction_loss', monitor, args.log_step)
    monitor_loss_gen = MonitorSeries('generator_loss', monitor, args.log_step)
    monitor_time = MonitorTimeElapsed("Training_time", monitor, args.log_step)

    # -------------------- Prepare / Split Dataset --------------------
    using_attr = args.selected_attrs
    dataset, attr2idx, idx2attr = get_data_dict(args.attr_path, using_attr)
    random.seed(313)  # use fixed seed.
    random.shuffle(dataset)  # shuffle dataset.
    test_dataset = dataset[-2000:]  # extract 2000 images for test

    if args.num_data:
        # Use training data partially.
        training_dataset = dataset[:min(args.num_data, len(dataset) - 2000)]
    else:
        training_dataset = dataset[:-2000]
    print("Use {} images for training.".format(len(training_dataset)))

    # create data iterators.
    load_func = functools.partial(stargan_load_func, dataset=training_dataset,
                                  image_dir=args.celeba_image_dir,
                                  image_size=args.image_size,
                                  crop_size=args.celeba_crop_size)
    data_iterator = data_iterator_simple(load_func, len(
        training_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    load_func_test = functools.partial(stargan_load_func,
                                       dataset=test_dataset,
                                       image_dir=args.celeba_image_dir,
                                       image_size=args.image_size,
                                       crop_size=args.celeba_crop_size)
    test_data_iterator = data_iterator_simple(load_func_test, len(
        test_dataset), args.batch_size, with_file_cache=False,
        with_memory_cache=False)

    # Keep fixed test images for intermediate translation visualization.
    test_real_ndarray, test_label_ndarray = test_data_iterator.next()
    test_label_ndarray = test_label_ndarray.reshape(
        test_label_ndarray.shape + (1, 1))

    # -------------------- Training Loop --------------------
    one_epoch = data_iterator.size // args.batch_size
    num_max_iter = args.max_epoch * one_epoch

    for i in range(num_max_iter):
        # Get real images and labels.
        real_ndarray, label_ndarray = data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        # Generate target domain labels randomly.
        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        # ---------------- Train Discriminator -----------------
        # generate fake image.
        x_fake.forward(clear_no_need_grad=True)
        d_loss.forward(clear_no_need_grad=True)
        solver_dis.zero_grad()
        d_loss.backward(clear_buffer=True)
        solver_dis.update()
        monitor_loss_dis.add(i, d_loss.d.item())
        monitor_d_cls_loss.add(i, d_loss_cls.d.item())
        monitor_time.add(i)

        # -------------- Train Generator (every n_critic steps) --------------
        if (i + 1) % args.n_critic == 0:
            g_loss.forward(clear_no_need_grad=True)
            solver_dis.zero_grad()
            solver_gen.zero_grad()
            x_fake_unlinked.grad.zero()
            g_loss.backward(clear_buffer=True)
            # Propagate the gradient accumulated on x_fake_unlinked back
            # through the generator graph.
            x_fake.backward(grad=None)
            solver_gen.update()
            monitor_loss_gen.add(i, g_loss.d.item())
            monitor_g_cls_loss.add(i, g_loss_cls.d.item())
            monitor_recon_loss.add(i, g_loss_rec.d.item())
            monitor_time.add(i)

        if (i + 1) % args.sample_step == 0:
            # save image.
            save_results(i, args, x_real, x_fake,
                         label_org, label_trg, x_recon)
            if args.test_during_training:
                # translate images from test dataset.
                # NOTE(review): reuses `rand_idx` from the current training
                # batch to permute test labels — confirm this is intended.
                x_real.d, label_org.d = test_real_ndarray, test_label_ndarray
                label_trg.d = test_label_ndarray[rand_idx]
                x_fake.forward(clear_no_need_grad=True)
                save_results(i, args, x_real, x_fake, label_org,
                             label_trg, None, is_training=False)

        # Learning rates get decayed linearly over the second half.
        if (i + 1) > int(0.5 * num_max_iter) and (i + 1) % args.lr_update_step == 0:
            g_lr = max(0, g_lr - (args.lr_update_step *
                                  args.g_lr / float(0.5 * num_max_iter)))
            d_lr = max(0, d_lr - (args.lr_update_step *
                                  args.d_lr / float(0.5 * num_max_iter)))
            solver_gen.set_learning_rate(g_lr)
            solver_dis.set_learning_rate(d_lr)
            print('learning rates decayed, g_lr: {}, d_lr: {}.'.format(
                g_lr, d_lr))

    # Save parameters and training config.
    param_name = 'trained_params_{}.h5'.format(
        datetime.datetime.today().strftime("%m%d%H%M"))
    param_path = os.path.join(args.model_save_path, param_name)
    nn.save_parameters(param_path)
    config["pretrained_params"] = param_name

    with open(os.path.join(args.model_save_path, "training_conf_{}.json".format(datetime.datetime.today().strftime("%m%d%H%M"))), "w") as f:
        json.dump(config, f)

    # -------------------- Translation on test dataset --------------------
    for i in range(args.num_test):
        real_ndarray, label_ndarray = test_data_iterator.next()
        label_ndarray = label_ndarray.reshape(label_ndarray.shape + (1, 1))
        label_ndarray = label_ndarray.astype(float)
        x_real.d, label_org.d = real_ndarray, label_ndarray

        rand_idx = np.random.permutation(label_org.shape[0])
        label_trg.d = label_ndarray[rand_idx]

        x_fake.forward(clear_no_need_grad=True)
        save_results(i, args, x_real, x_fake, label_org,
                     label_trg, None, is_training=False)
def train(data_iterator, monitor, config, comm, args):
    """Train a VQ-VAE (or its PixelCNN prior), or sample from a trained prior.

    Args:
        data_iterator: Factory called as ``data_iterator(config, comm, train=...)``
            that returns a data iterator over the dataset.
        monitor: nnabla ``Monitor`` instance used for logging.
        config: Configuration dict with ``train``, ``dataset``, ``model``,
            ``prior``, ``val`` and ``monitor`` sections.
        comm: Communicator wrapper; only ``comm.rank == 0`` writes logs and
            checkpoints so multi-device runs do not duplicate output.
        args: Parsed CLI flags (``pixelcnn_prior``, ``sample_from_pixelcnn``,
            ``load_checkpoint``, ``sample_save_path``).
    """
    # Monitors exist only on rank 0; every other rank keeps None and the
    # trainer classes are expected to tolerate that.
    monitor_train_loss, monitor_train_recon = None, None
    monitor_val_loss, monitor_val_recon = None, None
    if comm.rank == 0:
        monitor_train_loss = MonitorSeries(
            config['monitor']['train_loss'], monitor,
            interval=config['train']['logger_step_interval'])
        monitor_train_recon = MonitorImageTile(
            config['monitor']['train_recon'], monitor,
            interval=config['train']['logger_step_interval'],
            num_images=config['train']['batch_size'])
        monitor_val_loss = MonitorSeries(
            config['monitor']['val_loss'], monitor,
            interval=config['train']['logger_step_interval'])
        monitor_val_recon = MonitorImageTile(
            config['monitor']['val_recon'], monitor,
            interval=config['train']['logger_step_interval'],
            num_images=config['train']['batch_size'])

    model = VQVAE(config)

    # When only sampling from a trained PixelCNN prior, no solver or data
    # loaders are needed.
    if not args.sample_from_pixelcnn:
        if config['train']['solver'] == 'adam':
            solver = S.Adam()
        else:
            # Fix: nnabla exposes the SGD-with-momentum solver as the
            # capitalized class ``S.Momentum``; ``S.momentum()`` raises
            # AttributeError.
            solver = S.Momentum()
        solver.set_learning_rate(config['train']['learning_rate'])

        train_loader = data_iterator(config, comm, train=True)
        # Imagenet runs skip validation (no validation loader is built).
        if config['dataset']['name'] != 'imagenet':
            val_loader = data_iterator(config, comm, train=False)
        else:
            val_loader = None
    else:
        solver, train_loader, val_loader = None, None, None

    # Choose trainer and epoch budget: plain VQ-VAE training vs. training
    # the GatedPixelCNN prior on top of a (frozen) VQ-VAE.
    if not args.pixelcnn_prior:
        trainer = VQVAEtrainer(model, solver, train_loader, val_loader,
                               monitor_train_loss, monitor_train_recon,
                               monitor_val_loss, monitor_val_recon, config, comm)
        num_epochs = config['train']['num_epochs']
    else:
        pixelcnn_model = GatedPixelCNN(config['prior'])
        trainer = TrainerPrior(model, pixelcnn_model, solver, train_loader,
                               val_loader, monitor_train_loss,
                               monitor_train_recon, monitor_val_loss,
                               monitor_val_recon, config, comm,
                               eval=args.sample_from_pixelcnn)
        num_epochs = config['prior']['train']['num_epochs']

    if os.path.exists(config['model']['checkpoint']) and (args.load_checkpoint or args.sample_from_pixelcnn):
        checkpoint_path = config['model']['checkpoint'] if not args.pixelcnn_prior else config['prior']['checkpoint']
        trainer.load_checkpoint(checkpoint_path, msg='Parameters loaded from {}'.format(
            checkpoint_path), pixelcnn=args.pixelcnn_prior,
            load_solver=not args.sample_from_pixelcnn)

    # Sampling mode: generate images and return without training.
    if args.sample_from_pixelcnn:
        trainer.random_generate(
            args.sample_from_pixelcnn, args.sample_save_path)
        return

    for epoch in range(num_epochs):
        trainer.train(epoch)

        if epoch % config['val']['interval'] == 0 and val_loader is not None:
            trainer.validate(epoch)

        if comm.rank == 0:
            # Save periodically and on the final epoch. Fix: compare against
            # the local ``num_epochs`` — the old check against
            # config['train']['num_epochs']-1 missed the last epoch when
            # training the PixelCNN prior, whose budget comes from
            # config['prior']['train']['num_epochs'].
            if epoch % config['train']['save_param_step_interval'] == 0 or epoch == num_epochs - 1:
                trainer.save_checkpoint(
                    config['model']['saved_models_dir'], epoch,
                    pixelcnn=args.pixelcnn_prior)
def main(args):
    """Semi-supervised CIFAR-10 training with stochastic regularization.

    Builds a supervised graph (cross-entropy + entropy regularization) on
    labeled data and an unsupervised graph (stochastic-regularization
    consistency + entropy terms) on unlabeled data, then alternates updates
    of two Adam solvers per iteration, scaling unsupervised gradients with a
    ``GradScaleContainer``. Evaluates on the test batch once per epoch.

    Args:
        args: Parsed CLI flags providing ``device_id`` and ``context``.
    """
    # Settings
    device_id = args.device_id
    batch_size = 100
    batch_size_eval = 100
    n_l_train_data = 4000   # number of labeled training samples
    n_train_data = 50000    # total number of training samples
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    # Fix: use integer division. The old ``n_train_data / batch_size``
    # produced a float, so ``range(n_iter)`` below raised TypeError on
    # Python 3 (the sibling script already wraps this in int()).
    iter_epoch = n_train_data // batch_size
    n_iter = n_epoch * iter_epoch
    extension_module = args.context

    # Model
    ## supervised
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred = cnn_model_003(ctx, x_l)
    loss_ce = ce_loss(ctx, pred, y_l)
    loss_er = er_loss(ctx, pred)
    loss_supervised = loss_ce + loss_er

    ## stochastic regularization: two noisy passes over the same unlabeled
    ## batch are pushed to agree (consistency loss) plus entropy terms.
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0 = cnn_model_003(ctx, x_u0)
    pred_x_u1 = cnn_model_003(ctx, x_u1)
    loss_sr = sr_loss(ctx, pred_x_u0, pred_x_u1)
    loss_er0 = er_loss(ctx, pred_x_u0)
    loss_er1 = er_loss(ctx, pred_x_u1)
    loss_unsupervised = loss_sr + loss_er0 + loss_er1

    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval = cnn_model_003(ctx, x_eval, test=True)

    # Solver: both solvers share the same parameter set; each update uses
    # the gradients of its own loss.
    with nn.context_scope(ctx):
        solver_supervised = S.Adam(alpha=learning_rate)
        solver_supervised.set_parameters(nn.get_parameters())
        solver_unsupervised = S.Adam(alpha=learning_rate)
        solver_unsupervised.set_parameters(nn.get_parameters())

    # Gradient Scale Container
    gsc = GradScaleContainer(len(nn.get_parameters()))

    # Dataset
    ## separate dataset into labeled/unlabeled splits on disk
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,  # TODO: use F.image_augmentation
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()
        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        ## compute scales and update with grads of supervised loss
        solver_supervised.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        gsc.set_scales_supervised_loss(nn.get_parameters())
        solver_supervised.update()
        ## compute scales and update with grads of unsupervised loss
        solver_unsupervised.zero_grad()
        loss_unsupervised.backward(clear_buffer=True)
        gsc.set_scales_unsupervised_loss(nn.get_parameters())
        gsc.scale_grad(ctx, nn.get_parameters())
        solver_unsupervised.update()

        # Evaluate once per epoch
        if (i + 1) % iter_epoch == 0:
            # Get data and set it to the variables
            x_data, y_data = data_reader.get_test_batch()

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = x_data[k:k + batch_size_eval, :]
                label = y_data[k:k + batch_size_eval, :]
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st,
                (1. - ve / iter_val) * 100)
            print(msg)
            st = time.time()
            epoch += 1
def train(args):
    """Train a WGAN-GP-style GAN (generator + critic with gradient penalty).

    Builds generator/discriminator graphs plus a gradient-penalty term on
    random interpolates of real and fake samples, then alternates
    ``args.n_critic`` discriminator steps with one generator step per
    iteration, logging losses and periodically saving image tiles and
    parameters.

    Args:
        args: Parsed CLI flags (context, model sizes, learning rates, betas,
            ``lambda_`` gradient-penalty weight, iteration counts, paths).
    """
    # Context
    ctx = get_extension_context(args.context, device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    lambda_ = args.lambda_  # gradient-penalty weight

    # Model
    # generator loss; persistent keeps these buffers alive for monitoring.
    z = nn.Variable([batch_size, latent])
    x_fake = generator(z, maps=maps, up=args.up).apply(persistent=True)
    p_fake = discriminator(x_fake, maps=maps)
    loss_gen = gan_loss(p_fake).apply(persistent=True)
    # discriminator loss (a second discriminator pass over x_fake)
    p_fake = discriminator(x_fake, maps=maps)
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, maps=maps)
    loss_dis = gan_loss(p_fake, p_real).apply(persistent=True)
    # gradient penalty on random interpolates between real and fake
    eps = F.rand(shape=[batch_size, 1, 1, 1])
    x_rmix = eps * x_real + (1.0 - eps) * x_fake
    p_rmix = discriminator(x_rmix, maps=maps)
    # NOTE(review): need_grad is flipped on after the graph through x_rmix
    # is already built — presumably nn.grad only needs the flag at
    # backward time; confirm against nnabla double-backward docs.
    x_rmix.need_grad = True  # Enabling gradient computation for double backward
    grads = nn.grad([p_rmix], [x_rmix])
    # Penalize deviation of the per-sample gradient L2 norm from 1.
    l2norms = [F.sum(g**2.0, [1, 2, 3])**0.5 for g in grads]
    gp = sum([F.mean((l - 1.0)**2.0) for l in l2norms])
    loss_dis += lambda_ * gp
    # generator with fixed latent value for test-time visualization
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    x_test = generator(z_test, maps=maps, test=True,
                       up=args.up).apply(persistent=True)

    # Solver: each solver only owns the parameters of its own scope.
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    monitor_loss_gen = MonitorSeries("Generator Loss", monitor, interval=10)
    monitor_loss_cri = MonitorSeries(
        "Negative Critic Loss", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training Time", monitor, interval=10)
    monitor_image_tile_train = MonitorImageTile("Image Tile Train",
                                                monitor,
                                                num_images=batch_size,
                                                interval=1,
                                                normalize_method=denormalize)
    monitor_image_tile_test = MonitorImageTile("Image Tile Test",
                                               monitor,
                                               num_images=batch_size,
                                               interval=1,
                                               normalize_method=denormalize)

    # Data Iterator
    di = data_iterator_cifar10(batch_size, True)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator: n_critic critic steps per generator step.
        x_fake.need_grad = False  # no need backward to generator
        for _ in range(args.n_critic):
            solver_dis.zero_grad()
            # Scale pixel values to [-1, 1] (assumes 0..255 input).
            x_real.d = di.next()[0] / 127.5 - 1.0
            z.d = np.random.randn(batch_size, latent)
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(clear_buffer=True)
            solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # need backward to generator
        solver_gen.zero_grad()
        z.d = np.random.randn(batch_size, latent)
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Monitor
        monitor_loss_gen.add(i, loss_gen.d)
        monitor_loss_cri.add(i, -loss_dis.d)
        monitor_time.add(i)

        # Save
        if i % args.save_interval == 0:
            # NOTE(review): x_test is only forwarded after the loop, so the
            # test tiles saved here may show stale values — confirm intended.
            monitor_image_tile_train.add(i, x_fake)
            monitor_image_tile_test.add(i, x_test)
            nn.save_parameters(
                os.path.join(args.monitor_path, "params_{}.h5".format(i)))

    # Last: final test-image forward, parameter snapshot and tiles.
    x_test.forward(clear_buffer=True)
    nn.save_parameters(
        os.path.join(args.monitor_path, "params_{}.h5".format(i)))
    monitor_image_tile_train.add(i, x_fake)
    monitor_image_tile_test.add(i, x_test)
def train():
    """Train a CIFAR-10 ResNet-23 classifier with SSL regularization.

    Builds training and test graphs, adds a structured-sparsity (SSL)
    regularization term on the parameters, and runs the optimization loop
    with periodic validation, saving parameters whenever the validation
    error improves and once more after the final iteration.
    """
    args = get_args()

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Create CNN network for both training and testing.
    if args.net == "cifar10_resnet23_prediction":
        model_prediction = cifar10_resnet23_prediction

    # TRAIN
    maps = 64
    data_iterator = data_iterator_cifar10
    c = 3
    h = w = 32
    n_train = 50000
    n_valid = 10000

    # Create input variables.
    image = nn.Variable([args.batch_size, c, h, w])
    label = nn.Variable([args.batch_size, 1])
    # Create model_prediction graph.
    pred = model_prediction(image, maps=maps, test=False)
    # Keep pred's buffer alive so training error can be computed after backward.
    pred.persistent = True
    # Create loss function.
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # SSL Regularization (filter- and channel-wise sparsity penalties).
    loss += ssl_regularization(nn.get_parameters(),
                               args.filter_decay, args.channel_decay)

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, c, h, w])
    vlabel = nn.Variable([args.batch_size, 1])
    # Create predition graph.
    vpred = model_prediction(vimage, maps=maps, test=True)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Initialize DataIterator
    data = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)
    # ve is pre-initialized so the best-model check below is valid on
    # iterations before the first validation has run.
    best_ve = 1.0
    ve = 1.0
    # Training loop.
    for i in range(args.max_iter):
        if i % args.val_interval == 0:
            # Validation over the whole validation set.
            ve = 0.0
            for j in range(int(n_valid / args.batch_size)):
                vimage.d, vlabel.d = vdata.next()
                vpred.forward(clear_buffer=True)
                ve += categorical_error(vpred.d, vlabel.d)
            ve /= int(n_valid / args.batch_size)
            monitor_verr.add(i, ve)
        # Save parameters whenever the last measured validation error improved.
        if ve < best_ve:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))
            best_ve = ve
        # Training forward
        image.d, label.d = data.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        e = categorical_error(pred.d, label.d)
        monitor_loss.add(i, loss.d.copy())
        monitor_err.add(i, e)
        monitor_time.add(i)

    # Final validation pass after the last update.
    ve = 0.0
    for j in range(int(n_valid / args.batch_size)):
        vimage.d, vlabel.d = vdata.next()
        vpred.forward(clear_buffer=True)
        ve += categorical_error(vpred.d, vlabel.d)
    ve /= int(n_valid / args.batch_size)
    monitor_verr.add(i, ve)

    parameter_file = os.path.join(args.model_save_path,
                                  'params_{:06}.h5'.format(args.max_iter))
    nn.save_parameters(parameter_file)
def train():
    """MNIST classifier training entry point.

    Parses CLI arguments, builds training and test graphs for the selected
    network (LeNet or ResNet), then runs the optimization loop with
    periodic validation, checkpointing and monitoring, saving the final
    parameters when done.
    """
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.contrib.context import extension_context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = extension_context(extension_module, device_id=args.device_id)
    nn.set_default_context(ctx)

    # Pick the prediction network used for both training and testing.
    predictors = {
        'lenet': mnist_lenet_prediction,
        'resnet': mnist_resnet_prediction,
    }
    if args.net not in predictors:
        raise ValueError("Unknown network type {}".format(args.net))
    mnist_cnn_prediction = predictors[args.net]

    # Training graph.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train)
    # Keep pred alive past backward so the training error can be read.
    pred.persistent = True
    loss = F.mean(F.softmax_cross_entropy(pred, label))

    # Test graph.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test)

    # Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data iterators for MNIST.
    from numpy.random import RandomState
    train_di = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    val_di = data_iterator_mnist(args.batch_size, False)

    # Optimization loop.
    for step in range(args.max_iter):
        if step % args.val_interval == 0:
            # Periodic validation over val_iter minibatches.
            err_sum = 0.0
            for _ in range(args.val_iter):
                vimage.d, vlabel.d = val_di.next()
                vpred.forward(clear_buffer=True)
                err_sum += categorical_error(vpred.d, vlabel.d)
            monitor_verr.add(step, err_sum / args.val_iter)
        if step % args.model_save_interval == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % step))
        # One SGD step: forward, backward, weight decay, update.
        image.d, label.d = train_di.next()
        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.weight_decay(args.weight_decay)
        solver.update()
        train_err = categorical_error(pred.d, label.d)
        monitor_loss.add(step, loss.d.copy())
        monitor_err.add(step, train_err)
        monitor_time.add(step)

    # Final validation pass after the last update.
    err_sum = 0.0
    for _ in range(args.val_iter):
        vimage.d, vlabel.d = val_di.next()
        vpred.forward(clear_buffer=True)
        err_sum += categorical_error(vpred.d, vlabel.d)
    monitor_verr.add(step, err_sum / args.val_iter)

    # Persist the trained parameters.
    nn.save_parameters(os.path.join(
        args.model_save_path,
        '{}_params_{:06}.h5'.format(args.net, args.max_iter)))
def train():
    """Train the X-UMX (CrossNet-Open-Unmix) source-separation model on MUSDB.

    Supports multi-GPU training via a communicator wrapper: the data
    iterators are sliced per rank, gradients are all-reduced during
    backward, and only rank 0 logs and saves checkpoints. Validation runs
    per epoch over fixed-duration chunks of each track, driving an LR
    scheduler (ReduceLROnPlateau) and early stopping.
    """
    # Check NNabla version
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors
    # setting up monitors for logging
    monitor_path = args.output
    monitor = Monitor(monitor_path)
    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries(
        'Validation loss', monitor, interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed(
        "training time per iteration", monitor, interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    # In multi-device runs, each rank processes its own slice of the data.
    if comm.n_procs > 1:
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)
        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    # Scale weight decay by the number of devices to keep the effective
    # decay consistent after gradient averaging.
    weight_decay = args.weight_decay * comm.n_procs
    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables.
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))

    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable(
        [1] + [8, valid_source.sample_rate * args.valid_dur])

    # create training graph: frequency-domain MSE loss plus
    # time-domain SDR loss mixed by args.mcoef.
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # create validation graph (same losses, test mode)
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau)
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Average gradients across devices during backward.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # VALIDATION: each track is split into fixed-duration chunks and the
        # chunk losses are averaged.
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop when the next chunk would be short or the track was
                # consumed exactly.
                if x[Ellipsis, sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # clear cache memory
        ext.clear_memory_cache()

        # Adjust LR and check early stopping on the validation loss.
        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # save best model
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
def main(args):
    """Semi-supervised CIFAR-10 training with uncertainty-weighted consistency.

    The model predicts both class scores and a log-variance; the supervised
    loss combines cross-entropy, entropy regularization and a sigma
    (log-variance) regularizer, while the unsupervised loss enforces
    consistency between two noisy passes weighted by the predicted
    uncertainty. A single Adam solver accumulates gradients from both
    losses each iteration.

    Args:
        args: Parsed CLI flags providing ``device_id``, ``context``,
            ``batch_size``, ``batch_size_eval`` and ``lambda_``.
    """
    # Settings
    device_id = args.device_id
    batch_size = args.batch_size
    batch_size_eval = args.batch_size_eval
    n_l_train_data = 4000   # number of labeled training samples
    n_train_data = 50000    # total number of training samples
    n_cls = 10
    learning_rate = 1. * 1e-3
    n_epoch = 300
    act = F.relu
    iter_epoch = n_train_data / batch_size
    n_iter = int(n_epoch * iter_epoch)
    extension_module = args.context
    lambda_ = args.lambda_  # weight of the sigma regularization terms

    # Model
    ## supervised
    batch_size, m, h, w = batch_size, 3, 32, 32
    ctx = extension_context(extension_module, device_id=device_id)
    x_l = nn.Variable((batch_size, m, h, w))
    y_l = nn.Variable((batch_size, 1))
    pred, log_var = cnn_model_003(ctx, x_l)
    one = F.constant(1., log_var.shape)
    loss_ce = ce_loss(ctx, pred, y_l)
    reg_sigma = sigma_regularization(ctx, log_var, one)
    loss_supervised = loss_ce + er_loss(ctx, pred) + lambda_ * reg_sigma

    ## stochastic regularization: two noisy passes over the same unlabeled
    ## batch, consistency weighted by predicted uncertainty.
    x_u0 = nn.Variable((batch_size, m, h, w))
    x_u1 = nn.Variable((batch_size, m, h, w))
    pred_x_u0, log_var0 = cnn_model_003(ctx, x_u0)
    pred_x_u1, log_var1 = cnn_model_003(ctx, x_u1)
    loss_sr = sr_loss_with_uncertainty(ctx, pred_x_u0, pred_x_u1,
                                       log_var0, log_var1)
    reg_sigma0 = sigma_regularization(ctx, log_var0, one)
    reg_sigma1 = sigma_regularization(ctx, log_var1, one)
    reg_sigmas = sigmas_regularization(ctx, log_var0, log_var1)
    loss_unsupervised = loss_sr + er_loss(ctx, pred_x_u0) + er_loss(ctx, pred_x_u1) \
        + lambda_ * (reg_sigma0 + reg_sigma1) + lambda_ * reg_sigmas

    ## evaluate
    batch_size_eval, m, h, w = batch_size, 3, 32, 32
    x_eval = nn.Variable((batch_size_eval, m, h, w))
    pred_eval, _ = cnn_model_003(ctx, x_eval, test=True)

    # Solver
    with nn.context_scope(ctx):
        solver = S.Adam(alpha=learning_rate)
        solver.set_parameters(nn.get_parameters())

    # Dataset
    ## separate dataset (split assumed already on disk; see commented call)
    home = os.environ.get("HOME")
    fpath = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    separator = Separator(n_l_train_data)
    #separator.separate_then_save(fpath)

    l_train_path = os.path.join(home, "datasets/cifar10/l_cifar-10.npz")
    u_train_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")
    test_path = os.path.join(home, "datasets/cifar10/cifar-10.npz")

    # data reader
    data_reader = Cifar10DataReader(l_train_path, u_train_path, test_path,
                                    batch_size=batch_size,
                                    n_cls=n_cls,
                                    da=True,
                                    shape=True)

    # Training loop
    print("# Training loop")
    epoch = 1
    st = time.time()
    acc_prev = 0.
    ve_best = 1.
    save_path_prev = ""
    for i in range(n_iter):
        # Get data and set it to the variables
        x_l0_data, x_l1_data, y_l_data = data_reader.get_l_train_batch()
        x_u0_data, x_u1_data, y_u_data = data_reader.get_u_train_batch()
        x_l.d, _, y_l.d = x_l0_data, x_l1_data, y_l_data
        x_u0.d, x_u1.d = x_u0_data, x_u1_data

        # Train: gradients of both losses accumulate into one update.
        loss_supervised.forward(clear_no_need_grad=True)
        loss_unsupervised.forward(clear_no_need_grad=True)
        solver.zero_grad()
        loss_supervised.backward(clear_buffer=True)
        loss_unsupervised.backward(clear_buffer=True)
        solver.update()

        # Evaluate once per epoch
        if int((i + 1) % iter_epoch) == 0:
            # Get data and set it to the variables.
            # NOTE(review): evaluation runs on the (unlabeled) training
            # images, not the test batch — confirm intended.
            #x_data, y_data = data_reader.get_test_batch()
            u_train_data = data_reader.u_train_data
            x_data, y_data = u_train_data["train_x"].reshape(
                -1, 3, 32, 32) / 255.0, u_train_data["train_y"],

            # Evaluation loop
            ve = 0.
            iter_val = 0
            for k in range(0, len(x_data), batch_size_eval):
                x_eval.d = get_test_data(x_data, k, batch_size_eval)
                label = get_test_data(y_data, k, batch_size_eval)
                pred_eval.forward(clear_buffer=True)
                ve += categorical_error(pred_eval.d, label)
                iter_val += 1
            ve /= iter_val
            msg = "Epoch:{},ElapsedTime:{},Acc:{:02f}".format(
                epoch,
                time.time() - st,
                (1. - ve) * 100)
            print(msg)
            # Track the best validation error seen so far.
            if ve < ve_best:
                ve_best = ve
            st = time.time()
            epoch += 1
def train(args):
    """Train a CycleGAN (two generators g/f, two discriminators d_x/d_y).

    Builds the training graphs (generation, cycle reconstruction, optional
    identity loss, LSGAN losses with history buffers fed from image pools)
    plus test-time graphs, then alternates generator and discriminator
    updates, with per-epoch validation/monitoring, checkpointing and linear
    learning-rate decay.

    Args:
        args: Parsed CLI flags (context, dataset, unpool mode, init method,
            learning rate, loss weights, batch size, epochs, paths).
    """
    # Settings
    b, c, h, w = 1, 3, 256, 256
    beta1 = 0.5
    beta2 = 0.999
    pool_size = 50  # capacity of the fake-image history pools
    lambda_recon = args.lambda_recon
    lambda_idt = args.lambda_idt
    base_lr = args.learning_rate
    init_method = args.init_method

    # Context
    extension_module = args.context
    if args.context is None:
        extension_module = 'cpu'
    logger.info("Running in %s" % extension_module)
    ctx = get_extension_context(
        extension_module, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # Inputs: raw inputs are augmented into the "real" training images;
    # history variables are fed from the image pools below.
    x_raw = nn.Variable([b, c, h, w], need_grad=False)
    y_raw = nn.Variable([b, c, h, w], need_grad=False)
    x_real = image_augmentation(x_raw)
    y_real = image_augmentation(y_raw)
    x_history = nn.Variable([b, c, h, w])
    y_history = nn.Variable([b, c, h, w])
    x_real_test = nn.Variable([b, c, h, w], need_grad=False)
    y_real_test = nn.Variable([b, c, h, w], need_grad=False)

    # Models for training
    # Generate: g maps X->Y, f maps Y->X.
    y_fake = models.g(x_real, unpool=args.unpool, init_method=init_method)
    x_fake = models.f(y_real, unpool=args.unpool, init_method=init_method)
    y_fake.persistent, x_fake.persistent = True, True
    # Reconstruct (cycle consistency)
    x_recon = models.f(y_fake, unpool=args.unpool, init_method=init_method)
    y_recon = models.g(x_fake, unpool=args.unpool, init_method=init_method)
    # Discriminate
    d_y_fake = models.d_y(y_fake, init_method=init_method)
    d_x_fake = models.d_x(x_fake, init_method=init_method)
    d_y_real = models.d_y(y_real, init_method=init_method)
    d_x_real = models.d_x(x_real, init_method=init_method)
    d_y_history = models.d_y(y_history, init_method=init_method)
    d_x_history = models.d_x(x_history, init_method=init_method)

    # Models for test
    y_fake_test = models.g(
        x_real_test, unpool=args.unpool, init_method=init_method)
    x_fake_test = models.f(
        y_real_test, unpool=args.unpool, init_method=init_method)
    y_fake_test.persistent, x_fake_test.persistent = True, True
    # Reconstruct
    x_recon_test = models.f(
        y_fake_test, unpool=args.unpool, init_method=init_method)
    y_recon_test = models.g(
        x_fake_test, unpool=args.unpool, init_method=init_method)

    # Losses
    # Reconstruction Loss
    loss_recon = models.recon_loss(x_recon, x_real) \
        + models.recon_loss(y_recon, y_real)
    # Generator loss
    loss_gen = models.lsgan_loss(d_y_fake) \
        + models.lsgan_loss(d_x_fake) \
        + lambda_recon * loss_recon
    # Identity loss (optional)
    if lambda_idt != 0:
        logger.info("Identity loss was added.")
        # Identity: each generator applied to its own target domain should
        # act as (near) identity.
        y_idt = models.g(y_real, unpool=args.unpool, init_method=init_method)
        x_idt = models.f(x_real, unpool=args.unpool, init_method=init_method)
        loss_idt = models.recon_loss(x_idt, x_real) \
            + models.recon_loss(y_idt, y_real)
        loss_gen += lambda_recon * lambda_idt * loss_idt
    # Discriminator losses (fake side comes from the history pools)
    loss_dis_y = models.lsgan_loss(d_y_history, d_y_real)
    loss_dis_x = models.lsgan_loss(d_x_history, d_x_real)

    # Solvers: one per parameter scope.
    solver_gen = S.Adam(base_lr, beta1, beta2)
    solver_dis_x = S.Adam(base_lr, beta1, beta2)
    solver_dis_y = S.Adam(base_lr, beta1, beta2)
    with nn.parameter_scope('generator'):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope('discriminator'):
        with nn.parameter_scope("x"):
            solver_dis_x.set_parameters(nn.get_parameters())
        with nn.parameter_scope("y"):
            solver_dis_y.set_parameters(nn.get_parameters())

    # Datasets
    rng = np.random.RandomState(313)
    ds_train_B = cycle_gan_data_source(
        args.dataset, train=True, domain="B", shuffle=True, rng=rng)
    ds_train_A = cycle_gan_data_source(
        args.dataset, train=True, domain="A", shuffle=True, rng=rng)
    ds_test_B = cycle_gan_data_source(
        args.dataset, train=False, domain="B", shuffle=False, rng=rng)
    ds_test_A = cycle_gan_data_source(
        args.dataset, train=False, domain="A", shuffle=False, rng=rng)
    di_train_B = cycle_gan_data_iterator(ds_train_B, args.batch_size)
    di_train_A = cycle_gan_data_iterator(ds_train_A, args.batch_size)
    di_test_B = cycle_gan_data_iterator(ds_test_B, args.batch_size)
    di_test_A = cycle_gan_data_iterator(ds_test_A, args.batch_size)

    # Monitors
    monitor = Monitor(args.monitor_path)

    def make_monitor(name):
        # One-line factory for scalar-series monitors.
        return MonitorSeries(name, monitor, interval=1)
    monitor_loss_gen = make_monitor('generator_loss')
    monitor_loss_dis_x = make_monitor('discriminator_B_domain_loss')
    monitor_loss_dis_y = make_monitor('discriminator_A_domain_loss')

    def make_monitor_image(name):
        # Image monitors de-normalize from [-1, 1] back to [0, 255].
        return MonitorImage(name, monitor, interval=1,
                            normalize_method=lambda x: (x + 1.0) * 127.5)
    monitor_train_gx = make_monitor_image('fake_images_train_A')
    monitor_train_fy = make_monitor_image('fake_images_train_B')
    monitor_train_x_recon = make_monitor_image('fake_images_B_recon_train')
    monitor_train_y_recon = make_monitor_image('fake_images_A_recon_train')
    monitor_test_gx = make_monitor_image('fake_images_test_A')
    monitor_test_fy = make_monitor_image('fake_images_test_B')
    monitor_test_x_recon = make_monitor_image('fake_images_recon_test_B')
    monitor_test_y_recon = make_monitor_image('fake_images_recon_test_A')
    monitor_train_list = [
        (monitor_train_gx, y_fake),
        (monitor_train_fy, x_fake),
        (monitor_train_x_recon, x_recon),
        (monitor_train_y_recon, y_recon),
        (monitor_loss_gen, loss_gen),
        (monitor_loss_dis_x, loss_dis_x),
        (monitor_loss_dis_y, loss_dis_y),
    ]
    monitor_test_list = [
        (monitor_test_gx, y_fake_test),
        (monitor_test_fy, x_fake_test),
        (monitor_test_x_recon, x_recon_test),
        (monitor_test_y_recon, y_recon_test)]

    # ImagePool: history buffers that feed the discriminators with a mix of
    # current and past fakes.
    pool_x = ImagePool(pool_size)
    pool_y = ImagePool(pool_size)

    # Training loop
    epoch = 0
    n_images = np.max([ds_train_B.size, ds_train_A.size]
                      )  # num. images for each domain
    max_iter = args.max_epoch * n_images // args.batch_size
    for i in range(max_iter):
        # Validation (once per epoch)
        if int((i+1) % (n_images // args.batch_size)) == 0:
            logger.info("Mode:Test,Epoch:{}".format(epoch))

            # Monitor for train
            for monitor, v in monitor_train_list:
                monitor.add(i, v.d)

            # Use training graph since there are no test mode
            x_data, _ = di_test_B.next()
            y_data, _ = di_test_A.next()
            x_real_test.d = x_data
            y_real_test.d = y_data
            x_recon_test.forward()
            y_recon_test.forward()

            # Monitor for test
            for monitor, v in monitor_test_list:
                monitor.add(i, v.d)

            # Save model
            nn.save_parameters(os.path.join(
                args.model_save_path, 'params_%06d.h5' % i))

            # Learning rate decay
            for solver in [solver_gen, solver_dis_x, solver_dis_y]:
                linear_decay(solver, base_lr, epoch, args.max_epoch)
            epoch += 1

        # Get data
        x_data, _ = di_train_B.next()
        y_data, _ = di_train_A.next()
        x_raw.d = x_data
        y_raw.d = y_data

        # Train Generators
        loss_gen.forward(clear_no_need_grad=False)
        solver_gen.zero_grad()
        loss_gen.backward(clear_buffer=True)
        solver_gen.update()

        # Insert and Get to/from pool
        x_history.d = pool_x.insert_then_get(x_fake.d)
        y_history.d = pool_y.insert_then_get(y_fake.d)

        # Train Discriminator Y
        loss_dis_y.forward(clear_no_need_grad=False)
        solver_dis_y.zero_grad()
        loss_dis_y.backward(clear_buffer=True)
        solver_dis_y.update()

        # Train Discriminator X
        loss_dis_x.forward(clear_no_need_grad=False)
        solver_dis_x.zero_grad()
        loss_dis_x.backward(clear_buffer=True)
        solver_dis_x.update()
output = F.relu(PF.affine(m, output_mlp_size)) if train: output = F.dropout(output, p=dropout_ratio) with nn.parameter_scope('output'): y = F.sigmoid(PF.affine(output, 1)) accuracy = F.mean(F.equal(F.round(y), t)) loss = F.mean(F.binary_cross_entropy( y, t)) + attention_penalty_coef * frobenius( F.batch_matmul(a, a, transpose_a=True) - batch_eye(batch_size, r)) return x, t, accuracy, loss # Create solver. x, t, accuracy, loss = build_self_attention_model(train=True) solver = S.Adam() solver.set_parameters(nn.get_parameters()) x, t, accuracy, loss = build_self_attention_model(train=True) trainer = Trainer(inputs=[x, t], loss=loss, metrics={ 'cross entropy': loss, 'accuracy': accuracy }, solver=solver) for epoch in range(max_epoch): x, t, accuracy, loss = build_self_attention_model(train=True) trainer.update_variables(inputs=[x, t], loss=loss, metrics={
loss_gen = F.mean( F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape))) fake_dis = fake.get_unlinked_variable(need_grad=True) fake_dis.need_grad = True # TODO: Workaround until v1.0.2 pred_fake_dis = I.discriminator(fake_dis) loss_dis = F.mean( F.sigmoid_cross_entropy(pred_fake_dis, F.constant(0, pred_fake_dis.shape))) # Realパスの設定 x = nn.Variable([batch_size, 1, 28, 28]) pred_real = I.discriminator(x) loss_dis += F.mean( F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape))) # ソルバーの作成 solver_gen = S.Adam(learning_rate, beta1=0.5) solver_dis = S.Adam(learning_rate, beta1=0.5) with nn.parameter_scope("gen"): solver_gen.set_parameters(nn.get_parameters()) with nn.parameter_scope("dis"): solver_dis.set_parameters(nn.get_parameters()) # パラメータスコープの使い方を見ておく。 print(len(nn.get_parameters())) with nn.parameter_scope("gen"): print(len(nn.get_parameters())) # パラメータスコープ内では、`get_parameters()`で取得できるパラメータがフィルタリングされ # る。 # モニターの設定 path = cache_dir(os.path.join(I.name, "monitor"))