# Assumed module-level imports for this training script; project-local
# helpers (get_args, model, model_tweak_digitscaps, save_nnp,
# data_iterator_mnist, categorical_error, load_checkpoint,
# save_checkpoint) are expected to be defined elsewhere in the example.
import os

import nnabla as nn
import nnabla.functions as F
import nnabla.solvers as S
import nnabla.utils.save as save
from nnabla.logger import logger


def train():
    '''
    Main script.
    '''
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # TRAIN
    # Build the training graph: normalize pixels to [0, 1] and convert the
    # integer label to a one-hot vector over the 10 digit classes.
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    x = image / 255.0
    t_onehot = F.one_hot(label, (10, ))
    with nn.parameter_scope("capsnet"):
        c1, pcaps, u_hat, caps, pred = model.capsule_net(
            x, test=False, aug=True,
            grad_dynamic_routing=args.grad_dynamic_routing)
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot)
    # Total loss = margin loss on the capsule outputs + scaled
    # reconstruction loss.
    loss_margin, loss_reconst, loss = model.capsule_loss(
        pred, t_onehot, recon, x)
    # Keep the prediction buffer alive so pred.d can be read for the
    # training error even when intermediate buffers are cleared.
    pred.persistent = True

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vx = vimage / 255.0
    with nn.parameter_scope("capsnet"):
        _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitors.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    train_iter = int(60000 / args.batch_size)
    val_iter = int(10000 / args.batch_size)
    logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter))
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1)
    monitor_rloss = MonitorSeries(
        "Training reconstruction loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_lr = MonitorSeries("Learning rate", monitor, interval=1)

    # Save the untrained network in NNP format for inference.
    m_image, m_label, m_noise, m_recon = model_tweak_digitscaps(
        args.batch_size)
    contents = save_nnp({'x1': m_image, 'x2': m_label, 'x3': m_noise},
                        {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path,
                           'capsnet_epoch0_result.nnp'), contents)

    # Initialize DataIterators for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)

    start_point = 0
    if args.checkpoint is not None:
        # Load weights and solver state from the specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Training loop.
    for e in range(start_point, args.max_epochs):
        # Learning rate decay: multiply by 0.9 at every epoch after the
        # first (exponential decay).
        learning_rate = solver.learning_rate()
        if e != 0:
            learning_rate *= 0.9
        solver.set_learning_rate(learning_rate)
        monitor_lr.add(e, learning_rate)

        # Training
        train_error = 0.0
        train_loss = 0.0
        train_mloss = 0.0
        train_rloss = 0.0
        for i in range(train_iter):
            image.d, label.d = data.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            train_error += categorical_error(pred.d, label.d)
            train_loss += loss.d
            train_mloss += loss_margin.d
            train_rloss += loss_reconst.d
        train_error /= train_iter
        train_loss /= train_iter
        train_mloss /= train_iter
        train_rloss /= train_iter

        # Validation
        val_error = 0.0
        for j in range(val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            val_error += categorical_error(vpred.d, vlabel.d)
        val_error /= val_iter

        # Monitor per-epoch statistics.
        monitor_time.add(e)
        monitor_loss.add(e, train_loss)
        monitor_mloss.add(e, train_mloss)
        monitor_rloss.add(e, train_rloss)
        monitor_err.add(e, train_error)
        monitor_verr.add(e, val_error)

        # Save weights and solver state so training can be resumed.
        save_checkpoint(args.monitor_path, e, solver)

    # Save the trained network in NNP format for inference.
    contents = save_nnp({'x1': m_image, 'x2': m_label, 'x3': m_noise},
                        {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_result.nnp'), contents)
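

# A minimal entry point, assuming this module is run as a script. The flags
# in the example invocation below are inferred from the attributes read off
# get_args() above (context, device_id, type_config, monitor_path, ...) and
# are illustrative rather than a confirmed CLI:
#
#     python train.py --context cudnn --device_id 0 --monitor_path tmp.monitor
#
if __name__ == '__main__':
    train()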