def train(): """ Main script. Steps: * Parse command line arguments. * Specify a context for computation. * Initialize DataIterator for MNIST. * Construct a computation graph for training and validation. * Initialize a solver and set parameter variables to it. * Create monitor instances for saving and displaying training stats. * Training loop * Computate error rate for validation data (periodically) * Get a next minibatch. * Execute forwardprop on the training graph. * Compute training error * Set parameter gradients zero * Execute backprop. * Solver updates parameters by using gradients computed by backprop. """ args = get_args() from numpy.random import seed seed(0) # Get context. from nnabla.ext_utils import get_extension_context logger.info("Running in %s" % args.context) ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Create CNN network for both training and testing. if args.net == 'lenet': mnist_cnn_prediction = mnist_lenet_prediction elif args.net == 'resnet': mnist_cnn_prediction = mnist_resnet_prediction else: raise ValueError("Unknown network type {}".format(args.net)) # TRAIN # Create input variables. image = nn.Variable([args.batch_size, 1, 28, 28]) label = nn.Variable([args.batch_size, 1]) # Create prediction graph. pred = mnist_cnn_prediction(image, test=False, aug=args.augment_train) pred.persistent = True # Create loss function. loss = F.mean(F.softmax_cross_entropy(pred, label)) # TEST # Create input variables. vimage = nn.Variable([args.batch_size, 1, 28, 28]) vlabel = nn.Variable([args.batch_size, 1]) # Create prediction graph. vpred = mnist_cnn_prediction(vimage, test=True, aug=args.augment_test) # Create Solver. If training from checkpoint, load the info. solver = S.Adam(args.learning_rate) solver.set_parameters(nn.get_parameters()) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Create monitor. from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed monitor = Monitor(args.monitor_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=10) monitor_err = MonitorSeries("Training error", monitor, interval=10) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=100) monitor_verr = MonitorSeries("Test error", monitor, interval=10) # save_nnp contents = save_nnp({'x': vimage}, {'y': vpred}, args.batch_size) save.save( os.path.join(args.model_save_path, '{}_result_epoch0.nnp'.format(args.net)), contents) # Initialize DataIterator for MNIST. from numpy.random import RandomState data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223)) vdata = data_iterator_mnist(args.batch_size, False) # Training loop. 
for i in range(start_point, args.max_iter): if i % args.val_interval == 0: # Validation ve = 0.0 for j in range(args.val_iter): vimage.d, vlabel.d = vdata.next() vpred.forward(clear_buffer=True) vpred.data.cast(np.float32, ctx) ve += categorical_error(vpred.d, vlabel.d) monitor_verr.add(i, ve / args.val_iter) if i % args.model_save_interval == 0: # save checkpoint file save_checkpoint(args.model_save_path, i, solver) # Training forward image.d, label.d = data.next() solver.zero_grad() loss.forward(clear_no_need_grad=True) loss.backward(clear_buffer=True) solver.weight_decay(args.weight_decay) solver.update() loss.data.cast(np.float32, ctx) pred.data.cast(np.float32, ctx) e = categorical_error(pred.d, label.d) monitor_loss.add(i, loss.d.copy()) monitor_err.add(i, e) monitor_time.add(i) ve = 0.0 for j in range(args.val_iter): vimage.d, vlabel.d = vdata.next() vpred.forward(clear_buffer=True) ve += categorical_error(vpred.d, vlabel.d) monitor_verr.add(i, ve / args.val_iter) parameter_file = os.path.join( args.model_save_path, '{}_params_{:06}.h5'.format(args.net, args.max_iter)) nn.save_parameters(parameter_file) # save_nnp_lastepoch contents = save_nnp({'x': vimage}, {'y': vpred}, args.batch_size) save.save( os.path.join(args.model_save_path, '{}_result.nnp'.format(args.net)), contents)
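# The helper below is not part of the original script. It is a minimal,
# hedged sketch of how the '.nnp' archive saved above could be loaded back
# for inference with nnabla's NnpLoader. The file path and the use of the
# first network name in the archive are illustrative assumptions only.
def infer_from_nnp(nnp_path, images):
    import numpy as np
    import nnabla as nn
    from nnabla.utils.nnp_graph import NnpLoader

    nnp = NnpLoader(nnp_path)  # e.g. '<monitor_path>/lenet_result.nnp' (assumed path)
    name = nnp.get_network_names()[0]  # assume the archive holds a single network
    net = nnp.get_network(name, batch_size=images.shape[0])
    x = list(net.inputs.values())[0]
    y = list(net.outputs.values())[0]
    x.d = images
    y.forward(clear_buffer=True)
    return np.argmax(y.d, axis=1)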
def train(): """ Main script. """ args = get_args() _ = nn.load_parameters(args.pretrained_model_path) if args.fine_tune: nnabla.parameter.pop_parameter('decoder/logits/affine/conv/W') nnabla.parameter.pop_parameter('decoder/logits/affine/conv/b') n_train_samples = args.train_samples n_val_samples = args.val_samples distributed = args.distributed compute_acc = args.compute_acc if distributed: # Communicator and Context from nnabla.ext_utils import get_extension_context extension_module = "cudnn" ctx = get_extension_context( extension_module, type_config=args.type_config) comm = C.MultiProcessDataParalellCommunicator(ctx) comm.init() n_devices = comm.size mpi_rank = comm.rank device_id = mpi_rank ctx.device_id = str(device_id) nn.set_default_context(ctx) else: # Get context. from nnabla.ext_utils import get_extension_context extension_module = args.context if args.context is None: extension_module = 'cpu' logger.info("Running in %s" % extension_module) ctx = get_extension_context( extension_module, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) n_devices = 1 device_id = 0 # training data data = data_iterator_segmentation( args.train_samples, args.batch_size, args.train_dir, args.train_label_dir, target_width=args.image_width, target_height=args.image_height) # validation data vdata = data_iterator_segmentation(args.val_samples, args.batch_size, args.val_dir, args.val_label_dir, target_width=args.image_width, target_height=args.image_height) if distributed: data = data.slice( rng=None, num_of_slices=n_devices, slice_pos=device_id) vdata = vdata.slice( rng=None, num_of_slices=n_devices, slice_pos=device_id) num_classes = args.num_class # Workaround to start with the same initialized weights for all workers. np.random.seed(313) t_model = get_model( args, test=False) t_model.pred.persistent = True # Not clearing buffer of pred in backward t_pred2 = t_model.pred.unlinked() t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1) * t_model.mask) / F.sum(t_model.mask) v_model = get_model( args, test=True) v_model.pred.persistent = True # Not clearing buffer of pred in forward v_pred2 = v_model.pred.unlinked() v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1) * v_model.mask) / F.sum(t_model.mask) # Create Solver solver = S.Momentum(args.learning_rate, 0.9) solver.set_parameters(nn.get_parameters()) # Load checkpoint start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Setting warmup. base_lr = args.learning_rate / n_devices warmup_iter = int(1. 
* n_train_samples / args.batch_size / args.accum_grad / n_devices) * args.warmup_epoch warmup_slope = base_lr * (n_devices - 1) / warmup_iter solver.set_learning_rate(base_lr) # Create monitor import nnabla.monitor as M monitor = M.Monitor(args.monitor_path) monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10) monitor_err = M.MonitorSeries("Training error", monitor, interval=10) monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1) monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1) monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10) monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10) monitor_vtime = M.MonitorTimeElapsed( "Validation time", monitor, interval=1) # save_nnp contents = save_nnp({'x': v_model.image}, { 'y': v_model.pred}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Deeplabv3plus_result_epoch0.nnp'), contents, variable_batch_size=False) # Training loop for i in range(start_point, int(args.max_iter / n_devices)): # Save parameters if i % (args.model_save_interval // n_devices) == 0 and device_id == 0: save_checkpoint(args.model_save_path, i, solver) # Validation if i % (args.val_interval // n_devices) == 0 and i != 0: vmiou_local = 0. val_iter_local = n_val_samples // args.batch_size vl_local = nn.NdArray() vl_local.zero() ve_local = nn.NdArray() ve_local.zero() for j in range(val_iter_local): images, labels, masks = vdata.next() v_model.image.d = images v_model.label.d = labels v_model.mask.d = masks v_model.image.data.cast(np.float32, ctx) v_model.label.data.cast(np.int32, ctx) v_model.loss.forward(clear_buffer=True) v_e.forward(clear_buffer=True) vl_local += v_model.loss.data ve_local += v_e.data # Mean IOU computation if compute_acc: vmiou_local += compute_miou(num_classes, labels, np.argmax(v_model.pred.d, axis=1), masks) vl_local /= val_iter_local ve_local /= val_iter_local if compute_acc: vmiou_local /= val_iter_local vmiou_ndarray = nn.NdArray.from_numpy_array( np.array(vmiou_local)) if distributed: comm.all_reduce(vl_local, division=True, inplace=True) comm.all_reduce(ve_local, division=True, inplace=True) if compute_acc: comm.all_reduce(vmiou_ndarray, division=True, inplace=True) if device_id == 0: monitor_vloss.add(i * n_devices, vl_local.data.copy()) monitor_verr.add(i * n_devices, ve_local.data.copy()) if compute_acc: monitor_miou.add(i * n_devices, vmiou_local) monitor_vtime.add(i * n_devices) # Training l = 0.0 e = 0.0 solver.zero_grad() e_acc = nn.NdArray(t_e.shape) e_acc.zero() l_acc = nn.NdArray(t_model.loss.shape) l_acc.zero() # Gradient accumulation loop for j in range(args.accum_grad): images, labels, masks = data.next() t_model.image.d = images t_model.label.d = labels t_model.mask.d = masks t_model.image.data.cast(np.float32, ctx) t_model.label.data.cast(np.int32, ctx) t_model.loss.forward(clear_no_need_grad=True) t_model.loss.backward(clear_buffer=True) # Accumulating gradients t_e.forward(clear_buffer=True) e_acc += t_e.data l_acc += t_model.loss.data # AllReduce if distributed: params = [x.grad for x in nn.get_parameters().values()] comm.all_reduce(params, division=False, inplace=False) comm.all_reduce(l_acc, division=True, inplace=True) comm.all_reduce(e_acc, division=True, inplace=True) solver.scale_grad(1./args.accum_grad) solver.weight_decay(args.weight_decay) solver.update() # Linear Warmup if i <= warmup_iter: lr = base_lr + warmup_slope * i solver.set_learning_rate(lr) if distributed: # Synchronize by averaging the weights over devices 
using allreduce if (i+1) % args.sync_weight_every_itr == 0: weights = [x.data for x in nn.get_parameters().values()] comm.all_reduce(weights, division=True, inplace=True) if device_id == 0: monitor_loss.add( i * n_devices, (l_acc / args.accum_grad).data.copy()) monitor_err.add( i * n_devices, (e_acc / args.accum_grad).data.copy()) monitor_time.add(i * n_devices) # Learning rate decay at scheduled iter --> changed to poly learning rate decay policy # if i in args.learning_rate_decay_at: solver.set_learning_rate(base_lr * ((1 - i / args.max_iter)**0.1)) if device_id == 0: nn.save_parameters(os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter)) contents = save_nnp({'x': v_model.image}, { 'y': v_model.pred}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Deeplabv3plus_result.nnp'), contents, variable_batch_size=False)
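# 'compute_miou' above is imported from elsewhere in this repository. The
# function below is only a hypothetical sketch of such a helper, assuming
# 'labels' and 'preds' are integer class maps of identical shape and 'masks'
# marks valid pixels; the real implementation may differ.
def compute_miou_sketch(num_classes, labels, preds, masks):
    import numpy as np

    labels = labels.reshape(-1).astype(np.int64)
    preds = preds.reshape(-1).astype(np.int64)
    valid = masks.reshape(-1) > 0
    labels, preds = labels[valid], preds[valid]

    # Confusion matrix over the valid pixels only.
    conf = np.bincount(num_classes * labels + preds,
                       minlength=num_classes ** 2).reshape(num_classes,
                                                           num_classes)
    inter = np.diag(conf).astype(np.float64)
    union = conf.sum(axis=0) + conf.sum(axis=1) - inter
    iou = inter / np.maximum(union, 1)
    # Average only over classes that actually appear in this batch.
    return float(iou[union > 0].mean())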
def train(args): """ Main script. """ # Get context. from nnabla.ext_utils import get_extension_context logger.info("Running in %s" % args.context) ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Create CNN network for both training and testing. margin = 1.0 # Margin for contrastive loss. # TRAIN # Create input variables. image0 = nn.Variable([args.batch_size, 1, 28, 28]) image1 = nn.Variable([args.batch_size, 1, 28, 28]) label = nn.Variable([args.batch_size]) # Create prediction graph. pred = mnist_lenet_siamese(image0, image1, test=False) # Create loss function. loss = F.mean(contrastive_loss(pred, label, margin)) # TEST # Create input variables. vimage0 = nn.Variable([args.batch_size, 1, 28, 28]) vimage1 = nn.Variable([args.batch_size, 1, 28, 28]) vlabel = nn.Variable([args.batch_size]) # Create prediction graph. vpred = mnist_lenet_siamese(vimage0, vimage1, test=True) vloss = F.mean(contrastive_loss(vpred, vlabel, margin)) # Create Solver. solver = S.Adam(args.learning_rate) solver.set_parameters(nn.get_parameters()) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Create monitor. import nnabla.monitor as M monitor = M.Monitor(args.monitor_path) monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10) monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=100) monitor_vloss = M.MonitorSeries("Test loss", monitor, interval=10) # Initialize DataIterator for MNIST. rng = np.random.RandomState(313) data = siamese_data_iterator(args.batch_size, True, rng) vdata = siamese_data_iterator(args.batch_size, False, rng) # Training loop. for i in range(start_point, args.max_iter): if i % args.val_interval == 0: # Validation ve = 0.0 for j in range(args.val_iter): vimage0.d, vimage1.d, vlabel.d = vdata.next() vloss.forward(clear_buffer=True) ve += vloss.d monitor_vloss.add(i, ve / args.val_iter) if i % args.model_save_interval == 0: # save checkpoint file save_checkpoint(args.model_save_path, i, solver) image0.d, image1.d, label.d = data.next() solver.zero_grad() # Training forward, backward and update loss.forward(clear_no_need_grad=True) loss.backward(clear_buffer=True) solver.weight_decay(args.weight_decay) solver.update() monitor_loss.add(i, loss.d.copy()) monitor_time.add(i) parameter_file = os.path.join(args.model_save_path, 'params_%06d.h5' % args.max_iter) nn.save_parameters(parameter_file)
def train():
    '''
    Main script.
    '''
    args = get_args()

    from numpy.random import seed
    seed(0)

    # Get context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in %s" % args.context)
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    # TRAIN
    image = nn.Variable([args.batch_size, 1, 28, 28])
    label = nn.Variable([args.batch_size, 1])
    x = image / 255.0
    t_onehot = F.one_hot(label, (10, ))
    with nn.parameter_scope("capsnet"):
        c1, pcaps, u_hat, caps, pred = model.capsule_net(
            x, test=False, aug=True,
            grad_dynamic_routing=args.grad_dynamic_routing)
    with nn.parameter_scope("capsnet_reconst"):
        recon = model.capsule_reconstruction(caps, t_onehot)
    loss_margin, loss_reconst, loss = model.capsule_loss(
        pred, t_onehot, recon, x)
    pred.persistent = True

    # TEST
    # Create input variables.
    vimage = nn.Variable([args.batch_size, 1, 28, 28])
    vlabel = nn.Variable([args.batch_size, 1])
    vx = vimage / 255.0
    with nn.parameter_scope("capsnet"):
        _, _, _, _, vpred = model.capsule_net(vx, test=True, aug=False)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Create monitor.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    train_iter = int(60000 / args.batch_size)
    val_iter = int(10000 / args.batch_size)
    logger.info("#Train: {} #Validation: {}".format(train_iter, val_iter))
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_mloss = MonitorSeries("Training margin loss", monitor, interval=1)
    monitor_rloss = MonitorSeries(
        "Training reconstruction loss", monitor, interval=1)
    monitor_err = MonitorSeries("Training error", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)
    monitor_lr = MonitorSeries("Learning rate", monitor, interval=1)

    # To_save_nnp
    m_image, m_label, m_noise, m_recon = model_tweak_digitscaps(
        args.batch_size)
    contents = save_nnp({'x1': m_image, 'x2': m_label, 'x3': m_noise},
                        {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_epoch0_result.nnp'),
              contents)

    # Initialize DataIterator for MNIST.
    from numpy.random import RandomState
    data = data_iterator_mnist(args.batch_size, True, rng=RandomState(1223))
    vdata = data_iterator_mnist(args.batch_size, False)

    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Training loop.
    for e in range(start_point, args.max_epochs):
        # Learning rate decay
        learning_rate = solver.learning_rate()
        if e != 0:
            learning_rate *= 0.9
        solver.set_learning_rate(learning_rate)
        monitor_lr.add(e, learning_rate)

        # Training
        train_error = 0.0
        train_loss = 0.0
        train_mloss = 0.0
        train_rloss = 0.0
        for i in range(train_iter):
            image.d, label.d = data.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            loss.backward(clear_buffer=True)
            solver.update()
            train_error += categorical_error(pred.d, label.d)
            train_loss += loss.d
            train_mloss += loss_margin.d
            train_rloss += loss_reconst.d
        train_error /= train_iter
        train_loss /= train_iter
        train_mloss /= train_iter
        train_rloss /= train_iter

        # Validation
        val_error = 0.0
        for j in range(val_iter):
            vimage.d, vlabel.d = vdata.next()
            vpred.forward(clear_buffer=True)
            val_error += categorical_error(vpred.d, vlabel.d)
        val_error /= val_iter

        # Monitor
        monitor_time.add(e)
        monitor_loss.add(e, train_loss)
        monitor_mloss.add(e, train_mloss)
        monitor_rloss.add(e, train_rloss)
        monitor_err.add(e, train_error)
        monitor_verr.add(e, val_error)

        save_checkpoint(args.monitor_path, e, solver)

    # To_save_nnp
    contents = save_nnp({'x1': m_image, 'x2': m_label, 'x3': m_noise},
                        {'y': m_recon}, args.batch_size)
    save.save(os.path.join(args.monitor_path, 'capsnet_result.nnp'), contents)
def train(): """ Main script. """ args = get_args() # Get context. from nnabla.ext_utils import get_extension_context extension_module = args.context if args.context is None: extension_module = 'cpu' logger.info("Running in %s" % extension_module) ctx = get_extension_context(extension_module, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) if args.tiny_mode: # We use Tiny ImageNet from Stanford CS231N class. # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/) # Tiny ImageNet consists of 200 categories, each category has 500 images # in training set. The image size is 64x64. To adapt ResNet into 64x64 # image inputs, the input image size of ResNet is set as 56x56, and # the stride in the first conv and the first max pooling are removed. # Please check README. data = data_iterator_tiny_imagenet(args.batch_size, 'train') vdata = data_iterator_tiny_imagenet(args.batch_size, 'val') num_classes = 200 else: # We use ImageNet. # (ImageNet, https://imagenet.herokuapp.com/) # ImageNet consists of 1000 categories, each category has 1280 images # in training set. The image size is various. To adapt ResNet into # 320x320 image inputs, the input image size of ResNet is set as # 224x224. We need to get tar file and create cache file(320x320 images). # Please check README. data = data_iterator_imagenet(args.batch_size, args.train_cachefile_dir) vdata = data_iterator_imagenet(args.batch_size, args.val_cachefile_dir) num_classes = 1000 t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode) t_model.pred.persistent = True # Not clearing buffer of pred in backward # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix. t_pred2 = t_model.pred.get_unlinked_variable() t_pred2.need_grad = False t_e = F.mean(F.top_n_error(t_pred2, t_model.label)) v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode) v_model.pred.persistent = True # Not clearing buffer of pred in forward # TODO: need_grad should be passed to get_unlinked_variable after v1.0.3 fix. v_pred2 = v_model.pred.get_unlinked_variable() v_pred2.need_grad = False v_e = F.mean(F.top_n_error(v_pred2, v_model.label)) # Save_nnp_Epoch0 contents = save_nnp({'x': v_model.image}, {'y': v_model.pred}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Imagenet_result_epoch0.nnp'), contents) # Create Solver. solver = S.Momentum(args.learning_rate, 0.9) solver.set_parameters(nn.get_parameters()) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Create monitor. import nnabla.monitor as M monitor = M.Monitor(args.monitor_path) monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10) monitor_err = M.MonitorSeries("Training error", monitor, interval=10) monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=10) monitor_verr = M.MonitorSeries("Validation error", monitor, interval=10) monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10) monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor, interval=10) # Training loop. for i in range(start_point, args.max_iter): # Save parameters if i % args.model_save_interval == 0: # save checkpoint file save_checkpoint(args.model_save_path, i, solver) # Validation if i % args.val_interval == 0 and i != 0: # Clear all intermediate memory to save memory. 
# t_model.loss.clear_recursive() l = 0.0 e = 0.0 for j in range(args.val_iter): images, labels = vdata.next() v_model.image.d = images v_model.label.d = labels v_model.image.data.cast(np.uint8, ctx) v_model.label.data.cast(np.int32, ctx) v_model.loss.forward(clear_buffer=True) v_e.forward(clear_buffer=True) l += v_model.loss.d e += v_e.d monitor_vloss.add(i, l / args.val_iter) monitor_verr.add(i, e / args.val_iter) monitor_vtime.add(i) # Clear all intermediate memory to save memory. # v_model.loss.clear_recursive() # Training l = 0.0 e = 0.0 solver.zero_grad() def accumulate_error(l, e, t_model, t_e): l += t_model.loss.d e += t_e.d return l, e # Gradient accumulation loop for j in range(args.accum_grad): images, labels = data.next() t_model.image.d = images t_model.label.d = labels t_model.image.data.cast(np.uint8, ctx) t_model.label.data.cast(np.int32, ctx) t_model.loss.forward(clear_no_need_grad=True) t_model.loss.backward(clear_buffer=True) # Accumulating gradients t_e.forward(clear_buffer=True) l, e = accumulate_error(l, e, t_model, t_e) solver.weight_decay(args.weight_decay) solver.update() monitor_loss.add(i, l / args.accum_grad) monitor_err.add(i, e / args.accum_grad) monitor_time.add(i) # Learning rate decay at scheduled iter if i in args.learning_rate_decay_at: solver.set_learning_rate(solver.learning_rate() * 0.1) nn.save_parameters( os.path.join(args.model_save_path, 'param_%06d.h5' % args.max_iter)) # Save_nnp contents = save_nnp({'x': v_model.image}, {'y': v_model.pred}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Imagenet_result.nnp'), contents)
def train(args): """ Main script. """ # Get context. from nnabla.ext_utils import get_extension_context logger.info("Running in %s" % args.context) ctx = get_extension_context(args.context, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) # Create CNN network for both training and testing. # TRAIN # Fake path z = nn.Variable([args.batch_size, 100, 1, 1]) fake = generator(z) fake.persistent = True # Not to clear at backward pred_fake = discriminator(fake) loss_gen = F.mean( F.sigmoid_cross_entropy(pred_fake, F.constant(1, pred_fake.shape))) fake_dis = fake.get_unlinked_variable(need_grad=True) fake_dis.need_grad = True # TODO: Workaround until v1.0.2 pred_fake_dis = discriminator(fake_dis) loss_dis = F.mean( F.sigmoid_cross_entropy(pred_fake_dis, F.constant(0, pred_fake_dis.shape))) # Real path x = nn.Variable([args.batch_size, 1, 28, 28]) pred_real = discriminator(x) loss_dis += F.mean( F.sigmoid_cross_entropy(pred_real, F.constant(1, pred_real.shape))) # Create Solver. solver_gen = S.Adam(args.learning_rate, beta1=0.5) solver_dis = S.Adam(args.learning_rate, beta1=0.5) with nn.parameter_scope("gen"): solver_gen.set_parameters(nn.get_parameters()) with nn.parameter_scope("dis"): solver_dis.set_parameters(nn.get_parameters()) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint files. start_point = load_checkpoint(args.checkpoint, { "gen": solver_gen, "dis": solver_dis }) # Create monitor. import nnabla.monitor as M monitor = M.Monitor(args.monitor_path) monitor_loss_gen = M.MonitorSeries("Generator loss", monitor, interval=10) monitor_loss_dis = M.MonitorSeries("Discriminator loss", monitor, interval=10) monitor_time = M.MonitorTimeElapsed("Time", monitor, interval=100) monitor_fake = M.MonitorImageTile("Fake images", monitor, normalize_method=lambda x: (x + 1) / 2.) data = data_iterator_mnist(args.batch_size, True) # Save_nnp contents = save_nnp({'x': z}, {'y': fake}, args.batch_size) save.save( os.path.join(args.model_save_path, 'Generator_result_epoch0.nnp'), contents) contents = save_nnp({'x': x}, {'y': pred_real}, args.batch_size) save.save( os.path.join(args.model_save_path, 'Discriminator_result_epoch0.nnp'), contents) # Training loop. for i in range(start_point, args.max_iter): if i % args.model_save_interval == 0: save_checkpoint(args.model_save_path, i, { "gen": solver_gen, "dis": solver_dis }) # Training forward image, _ = data.next() x.d = image / 255. - 0.5 # [0, 255] to [-1, 1] z.d = np.random.randn(*z.shape) # Generator update. solver_gen.zero_grad() loss_gen.forward(clear_no_need_grad=True) loss_gen.backward(clear_buffer=True) solver_gen.weight_decay(args.weight_decay) solver_gen.update() monitor_fake.add(i, fake) monitor_loss_gen.add(i, loss_gen.d.copy()) # Discriminator update. 
solver_dis.zero_grad() loss_dis.forward(clear_no_need_grad=True) loss_dis.backward(clear_buffer=True) solver_dis.weight_decay(args.weight_decay) solver_dis.update() monitor_loss_dis.add(i, loss_dis.d.copy()) monitor_time.add(i) with nn.parameter_scope("gen"): nn.save_parameters( os.path.join(args.model_save_path, "generator_param_%06d.h5" % i)) with nn.parameter_scope("dis"): nn.save_parameters( os.path.join(args.model_save_path, "discriminator_param_%06d.h5" % i)) # Save_nnp contents = save_nnp({'x': z}, {'y': fake}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Generator_result.nnp'), contents) contents = save_nnp({'x': x}, {'y': pred_real}, args.batch_size) save.save(os.path.join(args.model_save_path, 'Discriminator_result.nnp'), contents)
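# The helper below is not part of the original script. It is a minimal,
# hedged sketch of how images could be sampled from the generator parameters
# saved above, assuming the generator() definition used in training is
# importable and registers its parameters under the "gen" scope, and that
# the file name matches the "generator_param_%06d.h5" pattern.
def generate_samples(param_path, batch_size=16):
    import numpy as np
    import nnabla as nn

    # The training script saved the generator weights under the "gen" scope,
    # so load them back into the same scope before rebuilding the graph.
    with nn.parameter_scope("gen"):
        nn.load_parameters(param_path)
    z = nn.Variable([batch_size, 100, 1, 1])
    fake = generator(z)  # same generator definition as in training (assumed)
    z.d = np.random.randn(*z.shape)
    fake.forward(clear_buffer=True)
    # Map the roughly [-1, 1] output back to [0, 255] uint8 images (assumption).
    return np.clip((fake.d + 1) / 2. * 255, 0, 255).astype(np.uint8)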
def train():
    args = get_args()

    # Set context.
    from nnabla.ext_utils import get_extension_context
    logger.info("Running in {}:{}".format(args.context, args.type_config))
    ctx = get_extension_context(
        args.context, device_id=args.device_id, type_config=args.type_config)
    nn.set_default_context(ctx)

    data_iterator = data_iterator_librispeech(args.batch_size, args.data_dir)
    _data_source = data_iterator._data_source  # dirty hack...

    # model
    x = nn.Variable(
        shape=(args.batch_size, data_config.duration, 1))  # (B, T, 1)
    onehot = F.one_hot(x, shape=(data_config.q_bit_len, ))  # (B, T, C)
    wavenet_input = F.transpose(onehot, (0, 2, 1))  # (B, C, T)

    # speaker embedding
    if args.use_speaker_id:
        s_id = nn.Variable(shape=(args.batch_size, 1))
        with nn.parameter_scope("speaker_embedding"):
            s_emb = PF.embed(
                s_id, n_inputs=_data_source.n_speaker,
                n_features=WavenetConfig.speaker_dims)
            s_emb = F.transpose(s_emb, (0, 2, 1))
    else:
        s_emb = None

    net = WaveNet()
    wavenet_output = net(wavenet_input, s_emb)

    pred = F.transpose(wavenet_output, (0, 2, 1))  # (B, T, C)

    t = nn.Variable(shape=(args.batch_size, data_config.duration, 1))

    loss = F.mean(F.softmax_cross_entropy(pred, t))

    # for generation
    prob = F.softmax(pred)

    # Create Solver.
    solver = S.Adam(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Create monitor.
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)

    # setup save env.
    audio_save_path = os.path.join(
        os.path.abspath(args.model_save_path), "audio_results")
    if audio_save_path and not os.path.exists(audio_save_path):
        os.makedirs(audio_save_path)

    # save_nnp
    contents = save_nnp({'x': x}, {'y': wavenet_output}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Speechsynthesis_result_epoch0.nnp'), contents)

    # Training loop.
    for i in range(start_point, args.max_iter):
        # todo: validation

        x.d, _speaker, t.d = data_iterator.next()
        if args.use_speaker_id:
            s_id.d = _speaker.reshape(-1, 1)

        solver.zero_grad()
        loss.forward(clear_no_need_grad=True)
        loss.backward(clear_buffer=True)
        solver.update()

        loss.data.cast(np.float32, ctx)
        monitor_loss.add(i, loss.d.copy())

        if i % args.model_save_interval == 0:
            prob.forward()
            audios = mu_law_decode(
                np.argmax(prob.d, axis=-1),
                quantize=data_config.q_bit_len)  # (B, T)
            save_audio(audios, i, audio_save_path)
            # save checkpoint file
            save_checkpoint(audio_save_path, i, solver)

    # save_nnp
    contents = save_nnp({'x': x}, {'y': wavenet_output}, args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Speechsynthesis_result.nnp'), contents)
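# 'mu_law_decode' above comes from this repository's utility code. The
# function below is only a hypothetical sketch of standard mu-law expansion,
# assuming 'output' holds integer bin indices in [0, quantize - 1]; the real
# helper may differ in signature and scaling.
def mu_law_decode_sketch(output, quantize=256):
    import numpy as np

    mu = quantize - 1
    # Map bin indices back to [-1, 1].
    y = 2.0 * output.astype(np.float32) / mu - 1.0
    # Invert the mu-law companding curve.
    return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu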
def train(): """ Main script. Steps: * Parse command line arguments. * Specify contexts for computation. * Initialize DataIterator. * Construct a computation graph for training and one for validation. * Initialize solver and set parameter variables to that. * Create monitor instances for saving and displaying training stats. * Training loop * Computate error rate for validation data (periodically) * Get a next minibatch. * Execute forwardprop * Set parameter gradients zero * Execute backprop. * Solver updates parameters by using gradients computed by backprop. * Compute training error """ # Parse args args = get_args() n_train_samples = 50000 bs_valid = args.batch_size extension_module = args.context ctx = get_extension_context( extension_module, device_id=args.device_id, type_config=args.type_config) nn.set_default_context(ctx) if args.net == "cifar10_resnet23": prediction = functools.partial( resnet23_prediction, ncls=10, nmaps=64, act=F.relu) data_iterator = data_iterator_cifar10 if args.net == "cifar100_resnet23": prediction = functools.partial( resnet23_prediction, ncls=100, nmaps=384, act=F.elu) data_iterator = data_iterator_cifar100 # Create training graphs test = False image_train = nn.Variable((args.batch_size, 3, 32, 32)) label_train = nn.Variable((args.batch_size, 1)) pred_train = prediction(image_train, test) loss_train = loss_function(pred_train, label_train) input_image_train = {"image": image_train, "label": label_train} # Create validation graph test = True image_valid = nn.Variable((bs_valid, 3, 32, 32)) pred_valid = prediction(image_valid, test) input_image_valid = {"image": image_valid} # Solvers solver = S.Adam() solver.set_parameters(nn.get_parameters()) start_point = 0 if args.checkpoint is not None: # load weights and solver state info from specified checkpoint file. start_point = load_checkpoint(args.checkpoint, solver) # Create monitor from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed monitor = Monitor(args.monitor_path) monitor_loss = MonitorSeries("Training loss", monitor, interval=10) monitor_err = MonitorSeries("Training error", monitor, interval=10) monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10) monitor_verr = MonitorSeries("Test error", monitor, interval=1) # Data Iterator tdata = data_iterator(args.batch_size, True) vdata = data_iterator(args.batch_size, False) # save_nnp contents = save_nnp({'x': image_valid}, {'y': pred_valid}, args.batch_size) save.save(os.path.join(args.model_save_path, '{}_epoch0_result.nnp'.format(args.net)), contents) # Training-loop for i in range(start_point, args.max_iter): # Validation if i % int(n_train_samples / args.batch_size) == 0: ve = 0. 
for j in range(args.val_iter): image, label = vdata.next() input_image_valid["image"].d = image pred_valid.forward() ve += categorical_error(pred_valid.d, label) ve /= args.val_iter monitor_verr.add(i, ve) if int(i % args.model_save_interval) == 0: # save checkpoint file save_checkpoint(args.model_save_path, i, solver) # Forward/Zerograd/Backward image, label = tdata.next() input_image_train["image"].d = image input_image_train["label"].d = label loss_train.forward() solver.zero_grad() loss_train.backward() # Solvers update solver.update() e = categorical_error( pred_train.d, input_image_train["label"].d) monitor_loss.add(i, loss_train.d.copy()) monitor_err.add(i, e) monitor_time.add(i) nn.save_parameters(os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter))) # save_nnp_lastepoch contents = save_nnp({'x': image_valid}, {'y': pred_valid}, args.batch_size) save.save(os.path.join(args.model_save_path, '{}_result.nnp'.format(args.net)), contents)