def create_communicator(ignore_error=False):
    global _current_communicator

    import nnabla_ext.cudnn
    from nnabla.ext_utils import get_extension_context

    extension_module = "cudnn"
    context = get_extension_context(extension_module)
    try:
        logger.log(99, 'Create communicator with contexts {}'.format(context))
        _current_communicator = C.MultiProcessDataParalellCommunicator(context)
        _current_communicator.init()
        context.device_id = str(_current_communicator.rank %
                                _current_communicator.size)
        if _current_communicator.size == 1:
            _current_communicator = None
    except:
        if not ignore_error:
            raise
        logger.warning("Failed to initialize nnabla.communicators.")
        _current_communicator = None

    return _current_communicator
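
# --- Added sketch (not from the original source): a minimal example of how
# the communicator returned by create_communicator() is typically consumed
# for data-parallel training. The model, data, and function name here are
# placeholders; it assumes the nnabla CUDA extension is installed and the
# script is launched with one process per GPU (e.g. via mpirun). t_data is
# assumed to be integer labels of shape (batch, 1).
import numpy as np
import nnabla as nn
import nnabla.functions as F
import nnabla.parametric_functions as PF
import nnabla.solvers as S


def data_parallel_step(comm, x_data, t_data):
    # Tiny classifier standing in for a real model.
    x = nn.Variable.from_numpy_array(x_data)
    t = nn.Variable.from_numpy_array(t_data)
    h = PF.affine(x, 10, name="fc")
    loss = F.mean(F.softmax_cross_entropy(h, t))
    solver = S.Sgd(0.01)
    solver.set_parameters(nn.get_parameters())
    loss.forward(clear_no_need_grad=True)
    solver.zero_grad()
    loss.backward(clear_buffer=True)
    if comm is not None:  # create_communicator returns None for one process
        # Sum gradients over devices, then divide by the number of devices.
        grads = [p.grad for p in nn.get_parameters().values()]
        comm.all_reduce(grads, division=True, inplace=False)
    solver.update()
    return float(loss.d)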
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop
        and all reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    n_valid_samples = 10000
    bs_valid = args.batch_size

    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=rng, ncls=10, nmaps=32, act=F.relu,
            comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu,
            comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(
        pred_train, label_train, axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(args.batch_size, False)

    # Training-loop
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # note that a smaller batch is ignored
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Save model
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(os.path.join(
                        args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward/AllReduce
        backward_and_all_reduce(
            loss_error_train, comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            # loss and error locally, and elapsed time
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
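
# --- Hedged sketch (not the original helper, which is defined elsewhere in
# the example) of what backward_and_all_reduce() above plausibly does: either
# a plain backward followed by one all_reduce of the gradients, or a backward
# that overlaps communication with computation via comm.all_reduce_callback
# (that callback API also appears in the last training script in this file).
# Note that loss_train above is pre-divided by n_devices, so a summing
# all_reduce (division=False) yields the device-averaged gradient.
def backward_and_all_reduce_sketch(loss, comm, with_all_reduce_callback=False):
    params = [p.grad for p in nn.get_parameters().values()]
    if with_all_reduce_callback:
        # Pack gradients into ~2 MiB buckets and all-reduce each bucket as
        # soon as its gradients are computed during backprop.
        callback = comm.all_reduce_callback(params, 2 << 20)
        loss.backward(clear_buffer=True, communicator_callbacks=callback)
    else:
        loss.backward(clear_buffer=True)
        comm.all_reduce(params, division=False, inplace=True)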
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000
    bs_valid = args.batch_size
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=rng, ncls=10, nmaps=64, act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(
            resnet23_prediction, rng=rng, ncls=100, nmaps=384, act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # add parameters to communicator
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solvers
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop
    for i in range(int(args.max_iter / n_devices)):
        # Validation
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(os.path.join(
                    args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
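
# --- Launch note (added; the exact script names are placeholders): nnabla
# multi-process data-parallel scripts like the two train() functions above
# are started with one process per GPU, typically via OpenMPI, e.g.
#
#   mpirun -n 4 python classification.py [args...]
#
# Each process then selects its GPU from comm.local_rank as shown above, so
# no extra per-process configuration is needed.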
import nnabla.communicators as C
import numpy as np
from functools import reduce  # needed by ref_reduce below on Python 3
from nbla_test_utils import list_context
from nnabla.contrib.context import extension_context

############################################
# Communicator has to be instantiated here,
# otherwise, mpirun fails.
############################################

# Communicator
comm = None
try:
    extension_module = "cuda"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
except:
    pass
############################################


def ref_reduce(x_data_list, size, division):
    f = reduce(lambda x, y: x + y, np.arange(size)) + size
    results = []
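
# --- Hedged sketch (added; not the original test body, which is truncated
# above) of the reference-check pattern such communicator tests follow:
# build a per-rank array, compute the expected reduction with NumPy, run the
# collective, and compare on every rank.
def check_all_reduce_against_numpy(comm):
    import nnabla as nn
    shape = (2, 3)
    # Each rank contributes an array filled with (rank + 1); the summed
    # result must equal 1 + 2 + ... + size on every rank.
    x = nn.NdArray.from_numpy_array(
        np.full(shape, comm.rank + 1, dtype=np.float32))
    expected = np.full(shape, comm.size * (comm.size + 1) / 2.,
                       dtype=np.float32)
    comm.all_reduce(x, division=False, inplace=True)
    np.testing.assert_allclose(x.data, expected)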
def train():
    """
    Main script.
    """
    args = get_args()

    _ = nn.load_parameters(args.pretrained_model_path)
    if args.fine_tune:
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/W')
        nnabla.parameter.pop_parameter('decoder/logits/affine/conv/b')

    n_train_samples = args.train_samples
    n_val_samples = args.val_samples
    distributed = args.distributed
    compute_acc = args.compute_acc

    if distributed:
        # Communicator and Context
        from nnabla.ext_utils import get_extension_context
        extension_module = "cudnn"
        ctx = get_extension_context(
            extension_module, type_config=args.type_config)
        comm = C.MultiProcessDataParalellCommunicator(ctx)
        comm.init()
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = mpi_rank
        ctx.device_id = str(device_id)
        nn.set_default_context(ctx)
    else:
        # Get context.
        from nnabla.ext_utils import get_extension_context
        extension_module = args.context
        if args.context is None:
            extension_module = 'cpu'
        logger.info("Running in %s" % extension_module)
        ctx = get_extension_context(
            extension_module, device_id=args.device_id,
            type_config=args.type_config)
        nn.set_default_context(ctx)
        n_devices = 1
        device_id = 0

    # training data
    data = data_iterator_segmentation(
        args.train_samples, args.batch_size, args.train_dir,
        args.train_label_dir, target_width=args.image_width,
        target_height=args.image_height)
    # validation data
    vdata = data_iterator_segmentation(
        args.val_samples, args.batch_size, args.val_dir, args.val_label_dir,
        target_width=args.image_width, target_height=args.image_height)

    if distributed:
        data = data.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
    num_classes = args.num_class

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.sum(F.top_n_error(t_pred2, t_model.label, axis=1)
                * t_model.mask) / F.sum(t_model.mask)

    v_model = get_model(args, test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.sum(F.top_n_error(v_pred2, v_model.label, axis=1)
                * v_model.mask) / F.sum(v_model.mask)

    # Create Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Load checkpoint
    start_point = 0
    if args.checkpoint is not None:
        # load weights and solver state info from specified checkpoint file.
        start_point = load_checkpoint(args.checkpoint, solver)

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples / args.batch_size /
                      args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_miou = M.MonitorSeries("mean IOU", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed(
        "Validation time", monitor, interval=1)

    # save_nnp
    contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                        args.batch_size)
    save.save(os.path.join(args.model_save_path,
                           'Deeplabv3plus_result_epoch0.nnp'),
              contents, variable_batch_size=False)

    # Training loop
    for i in range(start_point, int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            save_checkpoint(args.model_save_path, i, solver)

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            vmiou_local = 0.
            val_iter_local = n_val_samples // args.batch_size
            vl_local = nn.NdArray()
            vl_local.zero()
            ve_local = nn.NdArray()
            ve_local.zero()
            for j in range(val_iter_local):
                images, labels, masks = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.mask.d = masks
                v_model.image.data.cast(np.float32, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.data
                ve_local += v_e.data
                # Mean IOU computation
                if compute_acc:
                    vmiou_local += compute_miou(
                        num_classes, labels,
                        np.argmax(v_model.pred.d, axis=1), masks)
            vl_local /= val_iter_local
            ve_local /= val_iter_local
            if compute_acc:
                vmiou_local /= val_iter_local
                vmiou_ndarray = nn.NdArray.from_numpy_array(
                    np.array(vmiou_local))
            if distributed:
                comm.all_reduce(vl_local, division=True, inplace=True)
                comm.all_reduce(ve_local, division=True, inplace=True)
                if compute_acc:
                    comm.all_reduce(vmiou_ndarray, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl_local.data.copy())
                monitor_verr.add(i * n_devices, ve_local.data.copy())
                if compute_acc:
                    monitor_miou.add(i * n_devices, vmiou_local)
                monitor_vtime.add(i * n_devices)

        # Training
        solver.zero_grad()

        e_acc = nn.NdArray(t_e.shape)
        e_acc.zero()
        l_acc = nn.NdArray(t_model.loss.shape)
        l_acc.zero()

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels, masks = data.next()
            t_model.image.d = images
            t_model.label.d = labels
            t_model.mask.d = masks
            t_model.image.data.cast(np.float32, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            e_acc += t_e.data
            l_acc += t_model.loss.data

        # AllReduce
        if distributed:
            params = [x.grad for x in nn.get_parameters().values()]
            comm.all_reduce(params, division=False, inplace=False)
            comm.all_reduce(l_acc, division=True, inplace=True)
            comm.all_reduce(e_acc, division=True, inplace=True)

        solver.scale_grad(1. / args.accum_grad)
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        if distributed:
            # Synchronize by averaging the weights over devices
            # using allreduce
            if (i + 1) % args.sync_weight_every_itr == 0:
                weights = [x.data for x in nn.get_parameters().values()]
                comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(
                i * n_devices, (l_acc / args.accum_grad).data.copy())
            monitor_err.add(
                i * n_devices, (e_acc / args.accum_grad).data.copy())
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        # --> changed to poly learning rate decay policy
        # if i in args.learning_rate_decay_at:
        solver.set_learning_rate(base_lr * ((1 - i / args.max_iter) ** 0.1))

    if device_id == 0:
        nn.save_parameters(os.path.join(
            args.model_save_path, 'param_%06d.h5' % args.max_iter))
        contents = save_nnp({'x': v_model.image}, {'y': v_model.pred},
                            args.batch_size)
        save.save(os.path.join(args.model_save_path,
                               'Deeplabv3plus_result.nnp'),
                  contents, variable_batch_size=False)
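
# --- Hedged sketch (added; the real compute_miou is defined elsewhere in the
# example) of a confusion-matrix based mean-IoU computation compatible with
# the call compute_miou(num_classes, labels, np.argmax(pred, axis=1), masks)
# used above.
def compute_miou_sketch(num_classes, labels, preds, masks):
    labels = labels.flatten().astype(np.int64)
    preds = preds.flatten().astype(np.int64)
    valid = masks.flatten().astype(bool)
    labels, preds = labels[valid], preds[valid]
    # Confusion matrix via bincount over (true, predicted) class pairs.
    cm = np.bincount(num_classes * labels + preds,
                     minlength=num_classes ** 2
                     ).reshape(num_classes, num_classes)
    inter = np.diag(cm).astype(np.float64)
    union = cm.sum(axis=0) + cm.sum(axis=1) - inter
    iou = inter / np.maximum(union, 1)
    # Average IoU over classes that appear in labels or predictions.
    return iou[union > 0].mean()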
def comm_nccl_opts(request):
    """Common resources for communicator tests.
    """
    if not request.config.getoption('--test-communicator'):
        return None

    import nnabla.communicators as C
    from nnabla.ext_utils import get_extension_context
    try:
        from nnabla_ext import cuda
    except Exception as e:
        raise ImportError(
            "Communicator test requires CUDA extension.\n{}".format(e))

    gpus = request.config.getoption('--communicator-gpus')
    n_devices = cuda.get_device_count()
    if gpus is None:
        devices = list(map(str, range(n_devices)))
    else:
        devices = gpus.split(',')

    # Check numbers
    try:
        for d in devices:
            gid = int(d)
            if gid >= n_devices:
                raise ValueError('')
    except ValueError as e:
        raise ValueError(
            "GPU IDs must be comma separated integers of available GPUs."
            " Given {}. Available GPUs are {}.".format(gpus, n_devices))

    extension_module = "cuda"
    ctx = get_extension_context(extension_module)
    try:
        comm = C.MultiProcessDataParalellCommunicator(ctx)
    except Exception as e:
        raise RuntimeError(
            "Communicator could not be created. You may not have built with"
            " distributed support.\n{}".format(e))
    try:
        comm.init()
    except Exception as e:
        raise RuntimeError(
            "Communicator initialization failed. (Maybe MPI init failure.)\n"
            "{}".format(e))

    assert len(devices) == comm.size, \
        "Number of cuda devices used is not the same as that of processes."

    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    ctx.device_id = devices[mpi_local_rank]

    class CommOpts:
        pass

    c = CommOpts()
    c.comm = comm
    c.device_id = ctx.device_id
    c.devices = devices
    c.mpi_rank = mpi_rank
    c.mpi_local_rank = mpi_local_rank
    return c
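
# --- Usage sketch (added; the test name is a placeholder): how a test might
# consume the fixture above. The fixture returns None unless pytest is
# invoked with --test-communicator, so tests should skip in that case.
import pytest


def test_all_reduce_smoke(comm_nccl_opts):
    if comm_nccl_opts is None:
        pytest.skip("Communicator test is disabled. "
                    "Run pytest with --test-communicator.")
    comm = comm_nccl_opts.comm
    assert comm.size >= 1
    assert comm_nccl_opts.device_id == \
        comm_nccl_opts.devices[comm_nccl_opts.mpi_local_rank]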
def train(args):
    # Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = comm.local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Args
    latent = args.latent
    maps = args.maps
    batch_size = args.batch_size
    image_size = args.image_size
    n_classes = args.n_classes
    not_sn = args.not_sn

    # Model
    # workaround to start with the same weights in the distributed system.
    np.random.seed(412)
    # generator loss
    z = nn.Variable([batch_size, latent])
    y_fake = nn.Variable([batch_size])
    x_fake = generator(z, y_fake, maps=maps, n_classes=n_classes,
                       sn=not_sn).apply(persistent=True)
    p_fake = discriminator(x_fake, y_fake, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_gen = gan_loss(p_fake)
    # discriminator loss
    y_real = nn.Variable([batch_size])
    x_real = nn.Variable([batch_size, 3, image_size, image_size])
    p_real = discriminator(x_real, y_real, maps=maps // 16,
                           n_classes=n_classes, sn=not_sn)
    loss_dis = gan_loss(p_fake, p_real)
    # generator with fixed value for test
    z_test = nn.Variable.from_numpy_array(np.random.randn(batch_size, latent))
    y_test = nn.Variable.from_numpy_array(
        generate_random_class(n_classes, batch_size))
    x_test = generator(z_test, y_test, maps=maps, n_classes=n_classes,
                       test=True, sn=not_sn)

    # Solver
    solver_gen = S.Adam(args.lrg, args.beta1, args.beta2)
    solver_dis = S.Adam(args.lrd, args.beta1, args.beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
        solver_gen.set_parameters(params_gen)
    with nn.parameter_scope("discriminator"):
        params_dis = nn.get_parameters()
        solver_dis.set_parameters(params_dis)

    # Monitor
    if comm.rank == 0:
        monitor = Monitor(args.monitor_path)
        monitor_loss_gen = MonitorSeries(
            "Generator Loss", monitor, interval=10)
        monitor_loss_dis = MonitorSeries(
            "Discriminator Loss", monitor, interval=10)
        monitor_time = MonitorTimeElapsed(
            "Training Time", monitor, interval=10)
        monitor_image_tile_train = MonitorImageTile(
            "Image Tile Train", monitor, num_images=args.batch_size,
            interval=1, normalize_method=normalize_method)
        monitor_image_tile_test = MonitorImageTile(
            "Image Tile Test", monitor, num_images=args.batch_size,
            interval=1, normalize_method=normalize_method)

    # DataIterator
    rng = np.random.RandomState(device_id)
    di = data_iterator_imagenet(args.train_dir, args.dirname_to_label_path,
                                args.batch_size, n_classes=args.n_classes,
                                rng=rng)

    # Train loop
    for i in range(args.max_iter):
        # Train discriminator
        x_fake.need_grad = False  # no need for discriminator backward
        solver_dis.zero_grad()
        for _ in range(args.accum_grad):
            # feed x_real and y_real
            x_data, y_data = di.next()
            x_real.d, y_real.d = x_data, y_data.flatten()
            # feed z and y_fake
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_dis.forward(clear_no_need_grad=True)
            loss_dis.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_dis.values()])
        solver_dis.update()

        # Train generator
        x_fake.need_grad = True  # needed for generator backward
        solver_gen.zero_grad()
        for _ in range(args.accum_grad):
            z_data = np.random.randn(args.batch_size, args.latent)
            y_data = generate_random_class(args.n_classes, args.batch_size)
            z.d, y_fake.d = z_data, y_data
            loss_gen.forward(clear_no_need_grad=True)
            loss_gen.backward(
                1.0 / (args.accum_grad * n_devices), clear_buffer=True)
        comm.all_reduce([v.grad for v in params_gen.values()])
        solver_gen.update()

        # Synchronize by averaging the weights over devices using allreduce
        if i % args.sync_weight_every_itr == 0:
            weights = [v.data for v in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        # Save model and image
        if i % args.save_interval == 0 and comm.rank == 0:
            x_test.forward(clear_buffer=True)
            nn.save_parameters(os.path.join(
                args.monitor_path, "params_{}.h5".format(i)))
            monitor_image_tile_train.add(i, x_fake.d)
            monitor_image_tile_test.add(i, x_test.d)

        # Monitor
        if comm.rank == 0:
            monitor_loss_gen.add(i, loss_gen.d.copy())
            monitor_loss_dis.add(i, loss_dis.d.copy())
            monitor_time.add(i)

    if comm.rank == 0:
        x_test.forward(clear_buffer=True)
        nn.save_parameters(os.path.join(
            args.monitor_path, "params_{}.h5".format(i)))
        monitor_image_tile_train.add(i, x_fake.d)
        monitor_image_tile_test.add(i, x_test.d)
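
# --- Note on the gradient scaling above (added, schematic): backward() is
# called with the coefficient 1.0 / (accum_grad * n_devices), so the summed
# gradients after comm.all_reduce are already averaged over both the
# accumulation steps and the devices. That is why all_reduce is used with
# its default summation and no division:
#
#   for _ in range(accum_grad):
#       loss.forward(clear_no_need_grad=True)
#       loss.backward(1.0 / (accum_grad * n_devices), clear_buffer=True)
#   comm.all_reduce([v.grad for v in params.values()])  # sum over devices
#   solver.update()                                     # grads now averaged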
def train(args):
    # Create Communicator and Context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Input
    b, c, h, w = args.batch_size, 3, args.image_size, args.image_size
    x_real_a = nn.Variable([b, c, h, w])
    x_real_b = nn.Variable([b, c, h, w])

    # Model
    # workaround for starting with the same model among devices.
    np.random.seed(412)
    maps = args.maps
    # within-domain reconstruction (domain A)
    x_content_a = content_encoder(x_real_a, maps, name="content-encoder-a")
    x_style_a = style_encoder(x_real_a, maps, name="style-encoder-a")
    x_recon_a = decoder(x_content_a, x_style_a, name="decoder-a")
    # within-domain reconstruction (domain B)
    x_content_b = content_encoder(x_real_b, maps, name="content-encoder-b")
    x_style_b = style_encoder(x_real_b, maps, name="style-encoder-b")
    x_recon_b = decoder(x_content_b, x_style_b, name="decoder-b")
    # generate over domains and reconstruction of content and style (domain A)
    z_style_a = F.randn(shape=x_style_a.shape)
    x_fake_a = decoder(x_content_b, z_style_a, name="decoder-a")
    x_content_rec_b = content_encoder(x_fake_a, maps, name="content-encoder-a")
    x_style_rec_a = style_encoder(x_fake_a, maps, name="style-encoder-a")
    # generate over domains and reconstruction of content and style (domain B)
    z_style_b = F.randn(shape=x_style_b.shape)
    x_fake_b = decoder(x_content_a, z_style_b, name="decoder-b")
    x_content_rec_a = content_encoder(x_fake_b, maps, name="content-encoder-b")
    x_style_rec_b = style_encoder(x_fake_b, maps, name="style-encoder-b")
    # discriminate (domains A and B)
    p_x_fake_a_list = discriminators(x_fake_a)
    p_x_real_a_list = discriminators(x_real_a)
    p_x_fake_b_list = discriminators(x_fake_b)
    p_x_real_b_list = discriminators(x_real_b)

    # Loss
    # within-domain reconstruction
    loss_recon_x_a = recon_loss(x_recon_a, x_real_a).apply(persistent=True)
    loss_recon_x_b = recon_loss(x_recon_b, x_real_b).apply(persistent=True)
    # content and style reconstruction
    loss_recon_x_style_a = recon_loss(
        x_style_rec_a, z_style_a).apply(persistent=True)
    loss_recon_x_content_b = recon_loss(
        x_content_rec_b, x_content_b).apply(persistent=True)
    loss_recon_x_style_b = recon_loss(
        x_style_rec_b, z_style_b).apply(persistent=True)
    loss_recon_x_content_a = recon_loss(
        x_content_rec_a, x_content_a).apply(persistent=True)

    # adversarial
    def f(x, y):
        return x + y

    loss_gen_a = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_a_list]).apply(persistent=True)
    loss_dis_a = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_a_list, p_x_real_a_list)
    ]).apply(persistent=True)
    loss_gen_b = reduce(
        f, [lsgan_loss(p_f) for p_f in p_x_fake_b_list]).apply(persistent=True)
    loss_dis_b = reduce(f, [
        lsgan_loss(p_f, p_r)
        for p_f, p_r in zip(p_x_fake_b_list, p_x_real_b_list)
    ]).apply(persistent=True)
    # loss for generator-related models
    loss_gen = loss_gen_a + loss_gen_b \
        + args.lambda_x * (loss_recon_x_a + loss_recon_x_b) \
        + args.lambda_c * (loss_recon_x_content_a + loss_recon_x_content_b) \
        + args.lambda_s * (loss_recon_x_style_a + loss_recon_x_style_b)
    # loss for discriminators
    loss_dis = loss_dis_a + loss_dis_b

    # Solver
    lr_g, lr_d, beta1, beta2 = args.lr_g, args.lr_d, args.beta1, args.beta2
    # solver for generator-related models
    solver_gen = S.Adam(lr_g, beta1, beta2)
    with nn.parameter_scope("generator"):
        params_gen = nn.get_parameters()
    solver_gen.set_parameters(params_gen)
    # solver for discriminators
    solver_dis = S.Adam(lr_d, beta1, beta2)
    with nn.parameter_scope("discriminators"):
        params_dis = nn.get_parameters()
    solver_dis.set_parameters(params_dis)

    # Monitor
    monitor = Monitor(args.monitor_path)
    # time
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    # reconstruction
    monitor_loss_recon_x_a = MonitorSeries(
        "Recon Loss Image A", monitor, interval=10)
    monitor_loss_recon_x_content_b = MonitorSeries(
        "Recon Loss Content B", monitor, interval=10)
    monitor_loss_recon_x_style_a = MonitorSeries(
        "Recon Loss Style A", monitor, interval=10)
    monitor_loss_recon_x_b = MonitorSeries(
        "Recon Loss Image B", monitor, interval=10)
    monitor_loss_recon_x_content_a = MonitorSeries(
        "Recon Loss Content A", monitor, interval=10)
    monitor_loss_recon_x_style_b = MonitorSeries(
        "Recon Loss Style B", monitor, interval=10)
    # adversarial
    monitor_loss_gen_a = MonitorSeries("Gen Loss A", monitor, interval=10)
    monitor_loss_dis_a = MonitorSeries("Dis Loss A", monitor, interval=10)
    monitor_loss_gen_b = MonitorSeries("Gen Loss B", monitor, interval=10)
    monitor_loss_dis_b = MonitorSeries("Dis Loss B", monitor, interval=10)
    monitor_losses = [
        # reconstruction
        (monitor_loss_recon_x_a, loss_recon_x_a),
        (monitor_loss_recon_x_content_b, loss_recon_x_content_b),
        (monitor_loss_recon_x_style_a, loss_recon_x_style_a),
        (monitor_loss_recon_x_b, loss_recon_x_b),
        (monitor_loss_recon_x_content_a, loss_recon_x_content_a),
        (monitor_loss_recon_x_style_b, loss_recon_x_style_b),
        # adversarial
        (monitor_loss_gen_a, loss_gen_a),
        (monitor_loss_dis_a, loss_dis_a),
        (monitor_loss_gen_b, loss_gen_b),
        (monitor_loss_dis_b, loss_dis_b)
    ]
    # image
    monitor_image_a = MonitorImage(
        "Fake Image B to A Train", monitor, interval=1)
    monitor_image_b = MonitorImage(
        "Fake Image A to B Train", monitor, interval=1)
    monitor_images = [
        (monitor_image_a, x_fake_a),
        (monitor_image_b, x_fake_b),
    ]

    # DataIterator
    rng_a = np.random.RandomState(device_id)
    rng_b = np.random.RandomState(device_id + n_devices)
    di_a = munit_data_iterator(args.img_path_a, args.batch_size, rng=rng_a)
    di_b = munit_data_iterator(args.img_path_b, args.batch_size, rng=rng_b)

    # Train
    for i in range(args.max_iter // n_devices):
        ii = i * n_devices
        # Train generator-related models
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        solver_gen.zero_grad()
        loss_gen.forward(clear_no_need_grad=True)
        loss_gen.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_gen.values()])
        solver_gen.weight_decay(args.weight_decay_rate)
        solver_gen.update()

        # Train discriminators
        x_data_a, x_data_b = di_a.next()[0], di_b.next()[0]
        x_real_a.d, x_real_b.d = x_data_a, x_data_b
        x_fake_a.need_grad, x_fake_b.need_grad = False, False
        solver_dis.zero_grad()
        loss_dis.forward(clear_no_need_grad=True)
        loss_dis.backward(clear_buffer=True)
        comm.all_reduce([w.grad for w in params_dis.values()])
        solver_dis.weight_decay(args.weight_decay_rate)
        solver_dis.update()
        x_fake_a.need_grad, x_fake_b.need_grad = True, True

        # LR schedule
        if (i + 1) % (args.lr_decay_at_every // n_devices) == 0:
            lr_d = solver_dis.learning_rate() * args.lr_decay_rate
            lr_g = solver_gen.learning_rate() * args.lr_decay_rate
            solver_dis.set_learning_rate(lr_d)
            solver_gen.set_learning_rate(lr_g)

        if mpi_local_rank == 0:
            # Monitor
            monitor_time.add(ii)
            for mon, loss in monitor_losses:
                mon.add(ii, loss.d)
            # Save
            if (i + 1) % (args.model_save_interval // n_devices) == 0:
                for mon, x in monitor_images:
                    mon.add(ii, x.d)
                nn.save_parameters(os.path.join(
                    args.monitor_path, "param_{:05d}.h5".format(i)))

    if mpi_local_rank == 0:
        # Monitor
        for mon, loss in monitor_losses:
            mon.add(ii, loss.d)
        # Save
        for mon, x in monitor_images:
            mon.add(ii, x.d)
        nn.save_parameters(os.path.join(
            args.monitor_path, "param_{:05d}.h5".format(i)))
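
# --- Hedged sketch (added; the real recon_loss and lsgan_loss live elsewhere
# in the example and may differ, e.g. in how weights are passed): MUNIT-style
# reconstruction losses are plain L1 distances, and the least-squares GAN
# terms used above can be written with squared errors.
def recon_loss_sketch(x, y):
    # L1 reconstruction loss between two variables of identical shape.
    return F.mean(F.abs(x - y))


def lsgan_loss_sketch(p_fake, p_real=None):
    if p_real is None:
        # Generator side: push fake predictions toward 1.
        return F.mean(F.squared_error(p_fake, F.constant(1., p_fake.shape)))
    # Discriminator side: real toward 1, fake toward 0.
    return F.mean(F.squared_error(p_real, F.constant(1., p_real.shape))) + \
        F.mean(F.squared_error(p_fake, F.constant(0., p_fake.shape)))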
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single
        device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    args = get_args()
    if args.tiny_mode:
        n_train_samples = 100000
    else:
        n_train_samples = 1282167

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # workaround to start with the same parameters.
    rng = np.random.RandomState(device_id)
    if args.tiny_mode:
        # We use Tiny ImageNet from Stanford CS231N class.
        # (Tiny ImageNet, https://tiny-imagenet.herokuapp.com/)
        # Tiny ImageNet consists of 200 categories, each category has
        # 500 images in the training set. The image size is 64x64. To adapt
        # ResNet to 64x64 image inputs, the input image size of ResNet is
        # set as 56x56, and the stride in the first conv and the first
        # max pooling are removed. Please check README.
        data = data_iterator_tiny_imagenet(args.batch_size, 'train')
        vdata = data_iterator_tiny_imagenet(args.batch_size, 'val')
        num_classes = 200
    else:
        # We use ImageNet.
        # (ImageNet, https://imagenet.herokuapp.com/)
        # ImageNet consists of 1000 categories, each category has 1280
        # images in the training set. The image sizes vary. To adapt ResNet
        # to 320x320 image inputs, the input image size of ResNet is set as
        # 224x224. We need to get the tar file and create a cache file
        # (320x320 images). Please check README.
        data = data_iterator_imagenet(
            args.batch_size, args.train_cachefile_dir, rng=rng)
        vdata = data_iterator_imagenet(
            args.batch_size, args.val_cachefile_dir)
        vdata = vdata.slice(
            rng=None, num_of_slices=n_devices, slice_pos=device_id)
        num_classes = 1000

    # Workaround to start with the same initialized weights for all workers.
    np.random.seed(313)
    t_model = get_model(args, num_classes, test=False, tiny=args.tiny_mode)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.unlinked()
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))
    v_model = get_model(args, num_classes, test=True, tiny=args.tiny_mode)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.unlinked()
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Add parameters to communicator.
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create Solver.
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Setting warmup.
    base_lr = args.learning_rate / n_devices
    warmup_iter = int(1. * n_train_samples / args.batch_size /
                      args.accum_grad / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor.
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor,
                                         interval=1)

    # Training loop.
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                images, labels = vdata.next()
                v_model.image.d = images
                v_model.label.d = labels
                v_model.image.data.cast(np.uint8, ctx)
                v_model.label.data.cast(np.int32, ctx)
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            images, labels = data.next()
            if j != 0:
                # Update e and l according to previous results of forward
                # propagation.
                # The update of the last iteration is performed after the
                # solver update to avoid unnecessary CUDA synchronization.
                # This is performed after data.next() in order to overlap
                # the data loading and graph execution.
                # TODO: Move this to the bottom of the loop when a prefetch
                # data loader is available.
                l, e = accumulate_error(l, e, t_model, t_e)
            t_model.image.d = images
            t_model.label.d = labels
            t_model.image.data.cast(np.uint8, ctx)
            t_model.label.data.cast(np.int32, ctx)
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        # Accumulate errors after solver update
        l, e = accumulate_error(l, e, t_model, t_e)

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Synchronize by averaging the weights over devices using allreduce
        if (i + 1) % args.sync_weight_every_itr == 0:
            weights = [x.data for x in nn.get_parameters().values()]
            comm.all_reduce(weights, division=True, inplace=True)

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
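
# --- Worked check of the linear warmup used above (added sketch): the
# learning rate starts at base_lr = lr / n_devices and reaches
# n_devices * base_lr (the fully scaled rate) exactly at i == warmup_iter,
# since base_lr + warmup_slope * warmup_iter
#       = base_lr + base_lr * (n_devices - 1) = base_lr * n_devices.
def warmup_lr(base_lr, n_devices, warmup_iter, i):
    slope = base_lr * (n_devices - 1) / warmup_iter
    return base_lr + slope * min(i, warmup_iter)


assert abs(warmup_lr(0.1, 4, 100, 0) - 0.1) < 1e-12
assert abs(warmup_lr(0.1, 4, 100, 100) - 0.4) < 1e-12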
def train():
    """
    Main script.

    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * Inplace allreduce (THIS IS THE MAIN difference from a single
        device training)
      * Solver updates parameters by using gradients computed by backprop.
      * Compute training error
    """
    args = get_args()
    n_train_samples = 1281167
    num_classes = 1000

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Pipelines and Iterators for training
    train_pipes = [TrainPipeline(args.batch_size, args.num_threads, device_id,
                                 args.train_cachefile_dir, args.train_list,
                                 seed=device_id + 1, num_gpu=n_devices,
                                 random_area=args.random_area)]
    train_pipes[0].build()
    data = DALIClassificationIterator(
        train_pipes, train_pipes[0].epoch_size("Reader") // n_devices,
        auto_reset=True, stop_at_epoch=False)

    # Pipelines and Iterators for validation
    val_pipes = [ValPipeline(args.batch_size, args.num_threads, device_id,
                             args.val_cachefile_dir, args.val_list,
                             seed=device_id + 1, num_gpu=n_devices)]
    val_pipes[0].build()
    vdata = DALIClassificationIterator(
        val_pipes, val_pipes[0].epoch_size("Reader") // n_devices,
        auto_reset=True, stop_at_epoch=False)

    # Network for training
    t_model = get_model(args, num_classes, n_devices, args.accum_grad,
                        test=False)
    t_model.pred.persistent = True  # Not clearing buffer of pred in backward
    t_pred2 = t_model.pred.get_unlinked_variable(need_grad=False)
    t_e = F.mean(F.top_n_error(t_pred2, t_model.label))

    # Network for validation
    v_model = get_model(args, num_classes, n_devices, args.accum_grad,
                        test=True)
    v_model.pred.persistent = True  # Not clearing buffer of pred in forward
    v_pred2 = v_model.pred.get_unlinked_variable(need_grad=False)
    v_e = F.mean(F.top_n_error(v_pred2, v_model.label))

    # Solver
    solver = S.Momentum(args.learning_rate, 0.9)
    solver.set_learning_rate(args.learning_rate)
    solver.set_parameters(nn.get_parameters())

    # Monitors
    import nnabla.monitor as M
    monitor = M.Monitor(args.monitor_path)
    monitor_loss = M.MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = M.MonitorSeries("Training error", monitor, interval=10)
    monitor_vloss = M.MonitorSeries("Validation loss", monitor, interval=1)
    monitor_verr = M.MonitorSeries("Validation error", monitor, interval=1)
    monitor_time = M.MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_vtime = M.MonitorTimeElapsed("Validation time", monitor,
                                         interval=1)

    # Training loop
    vl = nn.Variable()
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Save parameters
        if i % (args.model_save_interval // n_devices) == 0 and device_id == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'param_%06d.h5' % i))

        # Validation
        if i % (args.val_interval // n_devices) == 0 and i != 0:
            ve_local = 0.
            vl_local = 0.
            val_iter_local = args.val_iter // n_devices
            for j in range(val_iter_local):
                nextImage, nextLabel = vdata.next()
                v_model.image.data = nextImage
                v_model.label.data = nextLabel
                v_model.loss.forward(clear_buffer=True)
                v_e.forward(clear_buffer=True)
                vl_local += v_model.loss.d.copy()
                ve_local += v_e.d.copy()
            vl_local /= val_iter_local
            vl.d = vl_local
            comm.all_reduce(vl.data, division=True, inplace=True)
            ve_local /= val_iter_local
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            if device_id == 0:
                monitor_vloss.add(i * n_devices, vl.d.copy())
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Training
        l = 0.0
        e = 0.0
        solver.zero_grad()

        def accumulate_error(l, e, t_model, t_e):
            l += t_model.loss.d
            e += t_e.d
            return l, e

        # Gradient accumulation loop
        for j in range(args.accum_grad):
            nextImage, nextLabel = data.next()
            t_model.image.data = nextImage
            t_model.label.data = nextLabel
            t_model.loss.forward(clear_no_need_grad=True)
            t_model.loss.backward(clear_buffer=True)  # Accumulating gradients
            t_e.forward(clear_buffer=True)
            l, e = accumulate_error(l, e, t_model, t_e)

        # AllReduce
        params = [x.grad for x in nn.get_parameters().values()]
        comm.all_reduce(params, division=False, inplace=False)

        # Update
        solver.weight_decay(args.weight_decay)
        solver.update()

        if device_id == 0:
            monitor_loss.add(i * n_devices, l / args.accum_grad)
            monitor_err.add(i * n_devices, e / args.accum_grad)
            monitor_time.add(i * n_devices)

        # Learning rate decay at scheduled iter
        if i * n_devices in args.learning_rate_decay_at:
            solver.set_learning_rate(solver.learning_rate() * 0.1)

    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'param_%06d.h5' % (args.max_iter / n_devices)))
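
# --- Note (added, schematic): unlike the earlier scripts that feed host
# numpy arrays through Variable.d, the DALI iterator above yields device
# arrays that are attached directly to Variable.data, avoiding an extra
# host-to-device copy:
#
#   images, labels = data.next()
#   t_model.image.d = images          # host feed: numpy -> device copy
#
#   nextImage, nextLabel = data.next()
#   t_model.image.data = nextImage    # device feed: swap in the NdArray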
def train(args):
    # get context
    ctx = get_extension_context(args.context)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    device_id = mpi_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    config = read_yaml(args.config)

    if args.info:
        config.monitor_params.info = args.info

    if comm.size == 1:
        comm = None
    else:
        # disable outputs from logger except for rank 0
        if comm.rank > 0:
            import logging
            logger.setLevel(logging.ERROR)

    test = False
    train_params = config.train_params
    dataset_params = config.dataset_params
    model_params = config.model_params

    loss_flags = get_loss_flags(train_params)

    start_epoch = 0

    rng = np.random.RandomState(device_id)
    data_iterator = frame_data_iterator(
        root_dir=dataset_params.root_dir,
        frame_shape=dataset_params.frame_shape,
        id_sampling=dataset_params.id_sampling,
        is_train=True,
        random_seed=rng,
        augmentation_params=dataset_params.augmentation_params,
        batch_size=train_params['batch_size'],
        shuffle=True,
        with_memory_cache=False,
        with_file_cache=False)

    if n_devices > 1:
        data_iterator = data_iterator.slice(
            rng=rng, num_of_slices=comm.size, slice_pos=comm.rank)
        # workaround not to use memory cache
        data_iterator._data_source._on_memory = False
        logger.info("Disabled on memory data cache.")

    bs, h, w, c = [train_params.batch_size] + dataset_params.frame_shape
    source = nn.Variable((bs, c, h, w))
    driving = nn.Variable((bs, c, h, w))

    with nn.parameter_scope("kp_detector"):
        # kp_X = {"value": Variable((bs, 10, 2)),
        #         "jacobian": Variable((bs, 10, 2, 2))}
        kp_source = detect_keypoint(source,
                                    **model_params.kp_detector_params,
                                    **model_params.common_params,
                                    test=test, comm=comm)
        persistent_all(kp_source)

        kp_driving = detect_keypoint(driving,
                                     **model_params.kp_detector_params,
                                     **model_params.common_params,
                                     test=test, comm=comm)
        persistent_all(kp_driving)

    with nn.parameter_scope("generator"):
        generated = occlusion_aware_generator(
            source, kp_source=kp_source, kp_driving=kp_driving,
            **model_params.generator_params,
            **model_params.common_params,
            test=test, comm=comm)
    # generated is a dictionary containing:
    # 'mask': Variable((bs, num_kp+1, h/4, w/4)) when scale_factor=0.25
    # 'sparse_deformed': Variable((bs, num_kp + 1, num_channel, h/4, w/4))
    # 'occlusion_map': Variable((bs, 1, h/4, w/4))
    # 'deformed': Variable((bs, c, h, w))
    # 'prediction': Variable((bs, c, h, w)) Only this is fed to discriminator.

    generated["prediction"].persistent = True

    pyramide_real = get_image_pyramid(driving, train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_real)

    pyramide_fake = get_image_pyramid(generated['prediction'],
                                      train_params.scales,
                                      generated["prediction"].shape[1])
    persistent_all(pyramide_fake)

    total_loss_G = None  # dummy; defined temporarily
    loss_var_dict = {}

    # perceptual loss using VGG19 (always applied)
    if loss_flags.use_perceptual_loss:
        logger.info("Use Perceptual Loss.")
        scales = train_params.scales
        weights = train_params.loss_weights.perceptual
        vgg_param_path = train_params.vgg_param_path
        percep_loss = perceptual_loss(pyramide_real, pyramide_fake,
                                      scales, weights, vgg_param_path)
        percep_loss.persistent = True
        loss_var_dict['perceptual_loss'] = percep_loss
        total_loss_G = percep_loss

    # (LS)GAN loss and feature matching loss
    if loss_flags.use_gan_loss:
        logger.info("Use GAN Loss.")
        with nn.parameter_scope("discriminator"):
            discriminator_maps_generated = multiscale_discriminator(
                pyramide_fake, kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)
            discriminator_maps_real = multiscale_discriminator(
                pyramide_real, kp=unlink_all(kp_driving),
                **model_params.discriminator_params,
                **model_params.common_params,
                test=test, comm=comm)

        for v in discriminator_maps_generated["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_generated["prediction_map_1"].persistent = True

        for v in discriminator_maps_real["feature_maps_1"]:
            v.persistent = True
        discriminator_maps_real["prediction_map_1"].persistent = True

        for i, scale in enumerate(model_params.discriminator_params.scales):
            key = f'prediction_map_{scale}'.replace('.', '-')
            lsgan_loss_weight = train_params.loss_weights.generator_gan
            # LSGAN loss for Generator
            if i == 0:
                gan_loss_gen = lsgan_loss(
                    discriminator_maps_generated[key], lsgan_loss_weight)
            else:
                gan_loss_gen += lsgan_loss(
                    discriminator_maps_generated[key], lsgan_loss_weight)
            # LSGAN loss for Discriminator
            if i == 0:
                gan_loss_dis = lsgan_loss(
                    discriminator_maps_real[key], lsgan_loss_weight,
                    discriminator_maps_generated[key])
            else:
                gan_loss_dis += lsgan_loss(
                    discriminator_maps_real[key], lsgan_loss_weight,
                    discriminator_maps_generated[key])

        gan_loss_dis.persistent = True
        loss_var_dict['gan_loss_dis'] = gan_loss_dis
        total_loss_D = gan_loss_dis
        total_loss_D.persistent = True

        gan_loss_gen.persistent = True
        loss_var_dict['gan_loss_gen'] = gan_loss_gen
        total_loss_G += gan_loss_gen

        if loss_flags.use_feature_matching_loss:
            logger.info("Use Feature Matching Loss.")
            fm_weights = train_params.loss_weights.feature_matching
            fm_loss = feature_matching_loss(discriminator_maps_real,
                                            discriminator_maps_generated,
                                            model_params, fm_weights)
            fm_loss.persistent = True
            loss_var_dict['feature_matching_loss'] = fm_loss
            total_loss_G += fm_loss

    # transform loss
    if loss_flags.use_equivariance_value_loss \
            or loss_flags.use_equivariance_jacobian_loss:
        transform = Transform(bs, **config.train_params.transform_params)
        transformed_frame = transform.transform_frame(driving)

        with nn.parameter_scope("kp_detector"):
            transformed_kp = detect_keypoint(
                transformed_frame,
                **model_params.kp_detector_params,
                **model_params.common_params,
                test=test, comm=comm)
        persistent_all(transformed_kp)

        # Value loss part
        if loss_flags.use_equivariance_value_loss:
            logger.info("Use Equivariance Value Loss.")
            warped_kp_value = transform.warp_coordinates(
                transformed_kp['value'])
            eq_value_weight = train_params.loss_weights.equivariance_value
            eq_value_loss = equivariance_value_loss(
                kp_driving['value'], warped_kp_value, eq_value_weight)
            eq_value_loss.persistent = True
            loss_var_dict['equivariance_value_loss'] = eq_value_loss
            total_loss_G += eq_value_loss

        # jacobian loss part
        if loss_flags.use_equivariance_jacobian_loss:
            logger.info("Use Equivariance Jacobian Loss.")
            arithmetic_jacobian = transform.jacobian(transformed_kp['value'])
            eq_jac_weight = train_params.loss_weights.equivariance_jacobian
            eq_jac_loss = equivariance_jacobian_loss(
                kp_driving['jacobian'], arithmetic_jacobian,
                transformed_kp['jacobian'], eq_jac_weight)
            eq_jac_loss.persistent = True
            loss_var_dict['equivariance_jacobian_loss'] = eq_jac_loss
            total_loss_G += eq_jac_loss

    assert total_loss_G is not None
    total_loss_G.persistent = True
    loss_var_dict['total_loss_gen'] = total_loss_G

    # -------------------- Create Monitors --------------------
    monitors_gen, monitors_dis, monitor_time, monitor_vis, log_dir = \
        get_monitors(config, loss_flags, loss_var_dict)

    if device_id == 0:
        # Dump training info .yaml
        _ = shutil.copy(args.config, log_dir)  # copy the config yaml
        training_info_yaml = os.path.join(log_dir, "training_info.yaml")
        os.rename(os.path.join(log_dir, os.path.basename(args.config)),
                  training_info_yaml)
        # then add additional information
        with open(training_info_yaml, "a", encoding="utf-8") as f:
            f.write(f"\nlog_dir: {log_dir}\nsaved_parameter: None")

    # -------------------- Solver Setup --------------------
    solvers = setup_solvers(train_params)
    solver_generator = solvers["generator"]
    solver_discriminator = solvers["discriminator"]
    solver_kp_detector = solvers["kp_detector"]

    # max epochs
    num_epochs = train_params['num_epochs']

    # iteration per epoch
    num_iter_per_epoch = data_iterator.size // bs
    # will be increased by num_repeats
    if 'num_repeats' in train_params and train_params['num_repeats'] != 1:
        num_iter_per_epoch *= config.train_params.num_repeats

    # modify learning rate if the current epoch exceeds the numbers
    # defined in epoch_milestones
    lr_decay_at_epochs = train_params['epoch_milestones']  # ex. [60, 90]
    gamma = 0.1  # decay rate

    # -------------------- For finetuning ---------------------
    if args.ft_params:
        assert os.path.isfile(args.ft_params)
        logger.info(f"load {args.ft_params} for finetuning.")
        nn.load_parameters(args.ft_params)
        start_epoch = int(os.path.splitext(
            os.path.basename(args.ft_params))[0].split("epoch_")[1])

        # set solver's state
        for name, solver in solvers.items():
            saved_states = os.path.join(
                os.path.dirname(args.ft_params),
                f"state_{name}_at_epoch_{start_epoch}.h5")
            solver.load_states(saved_states)

        start_epoch += 1
        logger.info(f"Resuming from epoch {start_epoch}.")

    logger.info(
        f"Start training. Total epoch: {num_epochs - start_epoch}, "
        f"{num_iter_per_epoch * n_devices} iter/epoch.")

    for e in range(start_epoch, num_epochs):
        logger.info(f"Epoch: {e} / {num_epochs}.")
        data_iterator._reset()  # rewind the iterator at the beginning

        # learning rate scheduler
        if e in lr_decay_at_epochs:
            logger.info("Learning rate decayed.")
            learning_rate_decay(solvers, gamma=gamma)

        for i in range(num_iter_per_epoch):
            _driving, _source = data_iterator.next()
            source.d = _source
            driving.d = _driving

            # update generator and keypoint detector
            total_loss_G.forward()

            if device_id == 0:
                monitors_gen.add((e * num_iter_per_epoch + i) * n_devices)

            solver_generator.zero_grad()
            solver_kp_detector.zero_grad()

            callback = None
            if n_devices > 1:
                params = [
                    x.grad
                    for x in solver_generator.get_parameters().values()
                ] + [
                    x.grad
                    for x in solver_kp_detector.get_parameters().values()
                ]
                callback = comm.all_reduce_callback(params, 2 << 20)
            total_loss_G.backward(clear_buffer=True,
                                  communicator_callbacks=callback)

            solver_generator.update()
            solver_kp_detector.update()

            if loss_flags.use_gan_loss:
                # update discriminator
                total_loss_D.forward(clear_no_need_grad=True)
                if device_id == 0:
                    monitors_dis.add((e * num_iter_per_epoch + i) * n_devices)

                solver_discriminator.zero_grad()

                callback = None
                if n_devices > 1:
                    params = [
                        x.grad for x in
                        solver_discriminator.get_parameters().values()
                    ]
                    callback = comm.all_reduce_callback(params, 2 << 20)
                total_loss_D.backward(clear_buffer=True,
                                      communicator_callbacks=callback)

                solver_discriminator.update()

            if device_id == 0:
                monitor_time.add((e * num_iter_per_epoch + i) * n_devices)

            if device_id == 0 and \
                    ((e * num_iter_per_epoch + i) * n_devices) \
                    % config.monitor_params.visualize_freq == 0:
                images_to_visualize = [
                    source.d, driving.d, generated["prediction"].d]
                visuals = combine_images(images_to_visualize)
                monitor_vis.add((e * num_iter_per_epoch + i) * n_devices,
                                visuals)

        if device_id == 0:
            if e % train_params.checkpoint_freq == 0 or e == num_epochs - 1:
                save_parameters(e, log_dir, solvers)

    return
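
# --- Hedged sketch (added; the real learning_rate_decay is defined elsewhere
# in the example) of the scheduler used above: multiply every solver's
# current learning rate by gamma at each epoch milestone.
def learning_rate_decay_sketch(solvers, gamma=0.1):
    for solver in solvers.values():
        lr = solver.learning_rate()
        solver.set_learning_rate(lr * gamma)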