Example #1
def main():
    args = get_args()

    # Setup
    from nnabla.ext_utils import get_extension_context
    if args.context is None:
        print(
            'Computation backend is not specified. Using the default "cudnn".')
        extension_module = "cudnn"
    else:
        extension_module = args.context
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    if args.raise_dataset_size:
        imagenet_val_size = 50000
        if imagenet_val_size % (comm.n_procs * args.batch_size) != 0:
            raise ValueError(
                f'The batch size and number of processes must be set so that {imagenet_val_size} is divisible by (batch_size * n_procs).'
            )

    # Load parameters
    channel_last, channels = load_parameters_and_config(
        args.weights, args.type_config)
    args.channel_last = channel_last

    # Build a validation network
    from models import build_network
    num_classes = args.num_classes
    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=channel_last,
                        spatial_size=args.spatial_size,
                        channels=channels)

    vdata = get_val_data_iterator(args, comm, channels, args.spatial_size,
                                  args.norm_config)

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Monitors
    import nnabla.monitor as M
    import os
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)

    from utils import EpochValidator
    EpochValidator(v_model, vdata, comm, monitor, stream_event_handler).run(0)
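
All of these examples share the same multi-GPU setup idiom: build a context with get_extension_context, wrap it in a CommunicatorWrapper, and register the wrapped context as the default. Below is a minimal sketch of what such a wrapper presumably does; setup_distributed_context is a hypothetical helper written against nnabla's public communicator API, not the CommunicatorWrapper actually imported above.

# Hypothetical sketch of the distributed-context setup assumed above.
import nnabla as nn
import nnabla.communicators as C
from nnabla.ext_utils import get_extension_context


def setup_distributed_context(extension='cudnn', type_config='float'):
    # One context per process; the communicator spans all launched processes.
    ctx = get_extension_context(extension, device_id='0', type_config=type_config)
    comm = C.MultiProcessDataParallelCommunicator(ctx)
    comm.init()
    # Re-create the context so that each process drives its local GPU.
    ctx = get_extension_context(extension,
                                device_id=str(comm.local_rank),
                                type_config=type_config)
    nn.set_default_context(ctx)
    return ctx, comm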
Example #2
def run(args):
    """Runs the algorithm."""
    Path(hp.output_path).mkdir(parents=True, exist_ok=True)

    # setup nnabla context
    ctx = get_extension_context(args.context, device_id='0')
    nn.set_default_context(ctx)
    hp.comm = CommunicatorWrapper(ctx)
    hp.event = StreamEventHandler(int(hp.comm.ctx.device_id))

    if hp.comm.n_procs > 1 and hp.comm.rank == 0:
        n_procs = hp.comm.n_procs
        logger.info(f'Distributed training with {n_procs} processes.')

    rng = np.random.RandomState(hp.seed)

    # setup optimizer
    lr_scheduler = NoamScheduler(hp.alpha, warmup=hp.warmup)
    optimizer = Optimizer(weight_decay=hp.weight_decay,
                          max_norm=hp.max_norm,
                          lr_scheduler=lr_scheduler,
                          name='Adam',
                          alpha=hp.alpha)

    # train data
    train_loader = data_iterator(LJSpeechDataSource('metadata_train.csv',
                                                    hp,
                                                    shuffle=True,
                                                    rng=rng),
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False)
    # valid data
    valid_loader = data_iterator(LJSpeechDataSource('metadata_valid.csv',
                                                    hp,
                                                    shuffle=False,
                                                    rng=rng),
                                 batch_size=hp.batch_size,
                                 with_memory_cache=False)
    dataloader = dict(train=train_loader, valid=valid_loader)
    model = Tacotron(hp)

    TacotronTrainer(model, dataloader, optimizer, hp).run()
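
The optimizer above is driven by NoamScheduler with a warmup. Assuming it follows the Noam schedule from the Transformer paper (rate proportional to min(step^-0.5, step * warmup^-1.5)), a standalone sketch of the curve looks like this; the class name and get_learning_rate method are illustrative, not the implementation actually imported.

# Hypothetical sketch of a Noam-style warmup/decay schedule.
class NoamSchedulerSketch:
    def __init__(self, alpha, warmup=4000):
        self.alpha = alpha      # base scale of the learning rate
        self.warmup = warmup    # number of warmup steps

    def get_learning_rate(self, step):
        step = max(step, 1)
        # Linear warmup for `warmup` steps, then inverse-sqrt decay.
        return self.alpha * min(step ** -0.5, step * self.warmup ** -1.5)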
Example #3
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("#GPU Count: ", comm.n_procs)

    data_iterator_train = jsi_iterator(conf.batch_size, conf, train=True)
    if conf.scaling_factor == 1:
        d_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)
        l_t = nn.Variable((conf.batch_size, 80, 80, 3), need_grad=True)

    else:
        d_t = nn.Variable((conf.batch_size, 160 // conf.scaling_factor,
                           160 // conf.scaling_factor, 3),
                          need_grad=True)
        l_t = nn.Variable((conf.batch_size, 160, 160, 3), need_grad=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))

    monitor = Monitor(monitor_path)
    jsi_monitor = setup_monitor(conf, monitor)

    with nn.parameter_scope("jsinet"):
        nn.load_parameters(conf.pre_trained_model)
        net = model(d_t, conf.scaling_factor)
        net.pred.persistent = True
    rec_loss = F.mean(F.squared_error(net.pred, l_t))
    rec_loss.persistent = True
    g_final_loss = rec_loss

    if conf.jsigan:
        net_gan = gan_model(l_t, net.pred, conf)
        d_final_fm_loss = net_gan.d_adv_loss
        d_final_fm_loss.persistent = True
        d_final_detail_loss = net_gan.d_detail_adv_loss
        d_final_detail_loss.persistent = True
        g_final_loss = conf.rec_lambda * rec_loss + conf.adv_lambda * (
            net_gan.g_adv_loss + net_gan.g_detail_adv_loss
        ) + conf.fm_lambda * (net_gan.fm_loss + net_gan.fm_detail_loss)
        g_final_loss.persistent = True

    max_iter = data_iterator_train._size // (conf.batch_size)
    if comm.rank == 0:
        print("max_iter", data_iterator_train._size, max_iter)

    iteration = 0
    if not conf.jsigan:
        start_epoch = 0
        end_epoch = conf.adv_weight_point
        lr = conf.learning_rate * comm.n_procs
    else:
        start_epoch = conf.adv_weight_point
        end_epoch = conf.epoch
        lr = conf.learning_rate * comm.n_procs
        w_d = conf.weight_decay * comm.n_procs

    # Set generator parameters
    with nn.parameter_scope("jsinet"):
        solver_jsinet = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_jsinet.set_parameters(nn.get_parameters())

    if conf.jsigan:
        solver_disc_fm = S.Adam(alpha=lr, beta1=0.9, beta2=0.999, eps=1e-08)
        solver_disc_detail = S.Adam(alpha=lr,
                                    beta1=0.9,
                                    beta2=0.999,
                                    eps=1e-08)
        with nn.parameter_scope("Discriminator_FM"):
            solver_disc_fm.set_parameters(nn.get_parameters())
        with nn.parameter_scope("Discriminator_Detail"):
            solver_disc_detail.set_parameters(nn.get_parameters())

    for epoch in range(start_epoch, end_epoch):
        for index in range(max_iter):
            d_t.d, l_t.d = data_iterator_train.next()

            if not conf.jsigan:
                # JSI-net -> Generator
                lr_stair_decay_points = [200, 225]
                lr_net = get_learning_rate(lr, iteration,
                                           lr_stair_decay_points,
                                           conf.lr_decreasing_factor)
                g_final_loss.forward(clear_no_need_grad=True)
                solver_jsinet.zero_grad()
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_net)
                solver_jsinet.update()
            else:
                # GAN part (discriminator + generator)
                lr_gan = lr if epoch < conf.gan_lr_linear_decay_point \
                    else lr * (end_epoch - epoch) / (end_epoch - conf.gan_lr_linear_decay_point)
                lr_gan = lr_gan * conf.gan_ratio

                net.pred.need_grad = False

                # Discriminator_FM
                solver_disc_fm.zero_grad()
                d_final_fm_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_fm_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_fm_loss.backward(clear_buffer=True)
                solver_disc_fm.set_learning_rate(lr_gan)
                solver_disc_fm.weight_decay(w_d)
                solver_disc_fm.update()

                # Discriminator_Detail
                solver_disc_detail.zero_grad()
                d_final_detail_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    d_final_detail_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    d_final_detail_loss.backward(clear_buffer=True)
                solver_disc_detail.set_learning_rate(lr_gan)
                solver_disc_detail.weight_decay(w_d)
                solver_disc_detail.update()

                # Generator
                net.pred.need_grad = True
                solver_jsinet.zero_grad()
                g_final_loss.forward(clear_no_need_grad=True)
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    g_final_loss.backward(
                        clear_buffer=True,
                        communicator_callbacks=all_reduce_callback)
                else:
                    g_final_loss.backward(clear_buffer=True)
                solver_jsinet.set_learning_rate(lr_gan)
                solver_jsinet.update()

            iteration += 1
            if comm.rank == 0:
                train_psnr = compute_psnr(net.pred.d, l_t.d, 1.)
                jsi_monitor['psnr'].add(iteration, train_psnr)
                jsi_monitor['rec_loss'].add(iteration, rec_loss.d.copy())
                jsi_monitor['time'].add(iteration)

            if comm.rank == 0:
                if conf.jsigan:
                    jsi_monitor['g_final_loss'].add(iteration,
                                                    g_final_loss.d.copy())
                    jsi_monitor['g_adv_loss'].add(iteration,
                                                  net_gan.g_adv_loss.d.copy())
                    jsi_monitor['g_detail_adv_loss'].add(
                        iteration, net_gan.g_detail_adv_loss.d.copy())
                    jsi_monitor['d_final_fm_loss'].add(
                        iteration, d_final_fm_loss.d.copy())
                    jsi_monitor['d_final_detail_loss'].add(
                        iteration, d_final_detail_loss.d.copy())
                    jsi_monitor['fm_loss'].add(iteration,
                                               net_gan.fm_loss.d.copy())
                    jsi_monitor['fm_detail_loss'].add(
                        iteration, net_gan.fm_detail_loss.d.copy())
                    jsi_monitor['lr'].add(iteration, lr_gan)

        if comm.rank == 0:
            if not os.path.exists(conf.output_dir):
                os.makedirs(conf.output_dir)
            with nn.parameter_scope("jsinet"):
                nn.save_parameters(
                    os.path.join(conf.output_dir,
                                 "model_param_%04d.h5" % epoch))
Example #4
def train():
    """
    Main script for training.
    """
    args, train_config = get_train_args()

    num_classes = args.num_classes

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"  # TODO: Hard coded!!!
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    # To utilize TensorCore in FP16
    channels = 4 if args.type_config == 'half' else 3

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Create data iterators
    data, vdata = get_data_iterators(args, comm, channels)

    # Create mixup object
    mixup = create_mixup_or_none(train_config.mixup, num_classes, comm)

    # Network for training
    t_model = get_model(args,
                        num_classes,
                        test=False,
                        channel_last=args.channel_last,
                        mixup=mixup,
                        channels=channels,
                        label_smoothing=train_config.label_smoothing,
                        ctx_for_loss=comm.ctx_float)

    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=args.channel_last,
                        channels=channels)

    # Solver
    # lr will be set later
    solver = MomentumNoWeightDecayBn(1, train_config.momentum)
    solver.set_parameters(nn.get_parameters())

    # Learning rate scheduler
    learning_rate_scheduler = create_learning_rate_scheduler(train_config)

    # Monitors
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)

    # Epoch runner
    loss_scaling = train_config.loss_scaling if args.type_config == 'half' else 1
    train_epoch = EpochTrainer(t_model, solver, learning_rate_scheduler, data,
                               comm, monitor, loss_scaling,
                               train_config.weight_decay, stream_event_handler,
                               mixup)
    val_epoch = None
    if args.val_interval > 0:
        val_epoch = EpochValidator(v_model, vdata, comm, monitor,
                                   stream_event_handler)

    # Epoch loop
    for epoch in range(train_config.epochs):
        # Save parameters
        if epoch > 0 and epoch % (
                args.model_save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(args.monitor_path, 'param_%03d.h5' % epoch))

        # Run validation for examples in an epoch
        if val_epoch is not None \
           and epoch > 0 \
           and epoch % args.val_interval == 0:
            val_epoch.run(epoch)

        # Run training for examples in an epoch
        train_epoch.run(epoch)

    # Run final validation
    if val_epoch is not None:
        val_epoch.run(train_config.epochs)

    # Save the final model.
    if comm.rank == 0:
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         'param_%03d.h5' % (train_config.epochs)))
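
The training network in this example optionally uses mixup via create_mixup_or_none. Assuming this refers to standard mixup (Zhang et al., 2018), the data/label mixing can be sketched as follows; mixup_batch is a hypothetical helper, not the object returned by create_mixup_or_none.

import numpy as np


def mixup_batch(x, y_onehot, alpha=0.2, rng=np.random):
    # Sample a mixing coefficient from Beta(alpha, alpha) and blend each
    # sample with a randomly chosen partner from the same batch.
    lam = rng.beta(alpha, alpha)
    idx = rng.permutation(len(x))
    x_mixed = lam * x + (1.0 - lam) * x[idx]
    y_mixed = lam * y_onehot + (1.0 - lam) * y_onehot[idx]
    return x_mixed, y_mixed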
Example #5
def main():
    """
    main - driver code to run training for Zooming SloMo
    """
    # Check NNabla version
    if get_nnabla_version_integer() < 11700:
        raise ValueError(
            'This does not work with nnabla versions older than v1.17.0, since the deformable_conv layer was added in v1.17.0. Please update nnabla.'
        )

    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(extension_module,
                                device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("comm rank", comm.rank)

    # Scale max_iter, learning_rate and cosine_period when the batch size or the number of GPUs changes.
    default_batch_size = 12
    train_scale_factor = comm.n_procs * \
        (conf.train.batch_size / default_batch_size)
    max_iter = int(conf.train.max_iter // train_scale_factor)
    learning_rate = conf.train.learning_rate * \
        (conf.train.batch_size / default_batch_size)
    cosine_period = int(conf.train.cosine_period // train_scale_factor)

    # for single-GPU training
    data_iterator_train = data_iterator(conf, shuffle=True)

    # for multi-GPU training
    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    # LR-LFR data for ZoomingSloMo input
    data_lr_lfr = nn.Variable(
        (conf.train.batch_size, (conf.data.n_frames // 2) + 1, 3,
         conf.data.lr_size, conf.data.lr_size))

    # HR-HFR data for ZoomingSloMo ground truth
    data_gt = nn.Variable((conf.train.batch_size, conf.data.n_frames, 3,
                           conf.data.gt_size, conf.data.gt_size))

    if conf.train.only_slomo:
        # High-resolution data is fed to the SloMo-only network for frame
        # interpolation, so a smaller number of frames is used.
        # LFR data for SloMo input
        slomo_gt = data_gt
        input_to_slomo = slomo_gt[:, 0:conf.data.n_frames:2, :, :, :]

    # setting up monitors for logging
    monitor_path = './nnmonitor'
    monitor = Monitor(monitor_path)
    monitor_loss = MonitorSeries('loss',
                                 monitor,
                                 interval=conf.train.monitor_log_freq)
    monitor_lr = MonitorSeries('learning rate',
                               monitor,
                               interval=conf.train.monitor_log_freq)
    monitor_time = MonitorTimeElapsed("training time per iteration",
                                      monitor,
                                      interval=conf.train.monitor_log_freq)

    scope_name = "ZoomingSloMo" if not conf.train.only_slomo else "SloMo"

    with nn.parameter_scope(scope_name):
        if conf.train.only_slomo:
            generated_frame = zooming_slo_mo_network(input_to_slomo,
                                                     conf.train.only_slomo)
            diff = generated_frame - slomo_gt
        else:
            generated_frame = zooming_slo_mo_network(data_lr_lfr,
                                                     conf.train.only_slomo)
            diff = generated_frame - data_gt

    # Charbonnier loss
    loss = F.sum((diff * diff + conf.train.eps)**0.5)

    # Define optimizer
    solver = S.Adam(alpha=learning_rate,
                    beta1=conf.train.beta1,
                    beta2=conf.train.beta2)

    # Set Parameters
    with nn.parameter_scope(scope_name):
        solver.set_parameters(nn.get_parameters())

    solver_dict = {scope_name: solver}

    if comm.rank == 0:
        print("maximum iterations", max_iter)

    start_point = 0
    if conf.train.checkpoint:
        # Load optimizer/solver information and model weights from checkpoint
        print("Loading weights from checkpoint:", conf.train.checkpoint)
        with nn.parameter_scope(scope_name):
            start_point = load_checkpoint(conf.train.checkpoint, solver_dict)

    if not os.path.isdir(conf.data.output_dir):
        os.makedirs(conf.data.output_dir)

    # Training loop.
    for i in range(start_point, max_iter):
        # Get Training Data
        if conf.train.only_slomo:
            _, data_gt.d = data_iterator_train.next()
        else:
            data_lr_lfr.d, data_gt.d = data_iterator_train.next()
        l_rate = get_repeated_cosine_annealing_learning_rate(
            i, learning_rate, conf.train.eta_min, cosine_period,
            conf.train.cosine_num_period)

        # Update
        solver.zero_grad()
        solver.set_learning_rate(l_rate)
        loss.forward(clear_no_need_grad=True)
        if comm.n_procs > 1:
            all_reduce_callback = comm.get_all_reduce_callback()
            loss.backward(clear_buffer=True,
                          communicator_callbacks=all_reduce_callback)
        else:
            loss.backward(clear_buffer=True)
        solver.update()

        if comm.rank == 0:
            monitor_loss.add(i, loss.d.copy())
            monitor_lr.add(i, l_rate)
            monitor_time.add(i)
            if (i % conf.train.save_checkpoint_freq) == 0:
                # Save intermediate check_points
                with nn.parameter_scope(scope_name):
                    save_checkpoint(conf.data.output_dir, i, solver_dict)

    # Save final model parameters
    if comm.rank == 0:
        with nn.parameter_scope(scope_name):
            nn.save_parameters(
                os.path.join(conf.data.output_dir, "final_model.h5"))
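
The learning rate above comes from get_repeated_cosine_annealing_learning_rate. A plausible sketch of such a schedule (cosine annealing from the base rate down to eta_min, restarted every cosine_period iterations for cosine_num_period cycles) is given below; this is an assumption about its behavior, not the imported implementation.

import math


def repeated_cosine_lr(i, base_lr, eta_min, period, num_period):
    # Hold at the floor once all cosine cycles have been consumed.
    if i >= period * num_period:
        return eta_min
    t = i % period  # position within the current cycle
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / period))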
Example #6
def main():
    conf = get_config()
    extension_module = conf.nnabla_context.context
    ctx = get_extension_context(
        extension_module, device_id=conf.nnabla_context.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    print("comm rank", comm.rank)

    # data iterators for train and val data
    from data_loader import data_iterator_sr, get_sample_name_grid, nn_data_gauss_down_quad

    sample_names = get_sample_name_grid(conf)
    num_samples = len(sample_names[0])
    print("No of training samples :", num_samples)

    # crop_size * 4, plus a margin for the Gaussian blur
    tar_size = (conf.train.crop_size * 4) + int(1.5 * 3.0) * 2

    data_iterator_train = data_iterator_sr(
        conf, num_samples, sample_names, tar_size, shuffle=True)

    if comm.n_procs > 1:
        data_iterator_train = data_iterator_train.slice(
            rng=None, num_of_slices=comm.n_procs, slice_pos=comm.rank)

    train_hr = nn.Variable(
        (conf.train.batch_size, conf.train.rnn_n, conf.train.crop_size*4, conf.train.crop_size*4, 3))
    data_hr = nn.Variable(
        (conf.train.batch_size, conf.train.rnn_n, tar_size, tar_size, 3))
    train_lr = nn_data_gauss_down_quad(data_hr.reshape(
        (conf.train.batch_size * conf.train.rnn_n, tar_size, tar_size, 3)))
    train_lr = F.reshape(
        train_lr, (conf.train.batch_size, conf.train.rnn_n, conf.train.crop_size, conf.train.crop_size, 3))

    # setting up monitors for logging
    monitor_path = './nnmonitor' + \
        str(datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    monitor = Monitor(monitor_path)
    common_monitor = get_common_monitors(monitor)

    # Scale max_iter and learning_rate when the batch size or the number of GPUs changes.
    div_factor = conf.train.batch_size * comm.n_procs
    max_iter = (conf.train.max_iter * 4) // div_factor
    learning_rate = conf.train.learning_rate * \
        (conf.train.batch_size / 4) * comm.n_procs

    if comm.rank == 0:
        print("maximum iterations", max_iter)

    scope_name = 'frvsr/'
    if conf.train.tecogan:
        scope_name = 'tecogan/'
        if not conf.train.checkpoint:
            print('loading pretrained FRVSR model',
                  conf.train.pre_trained_frvsr_weights)
            with nn.parameter_scope(scope_name):
                nn.load_parameters(conf.train.pre_trained_frvsr_weights)
                params_from_pre_trained_model = []
                for key, val in nn.get_parameters().items():
                    params_from_pre_trained_model.append(scope_name + key)

            network = get_tecogan_model(conf, train_lr, train_hr, scope_name)
            params_from_graph = nn.get_parameters()

            # Set the generator parameters that are not in FRVSR to zero,
            # as done in the original implementation.
            for key, val in params_from_graph.items():
                if key in params_from_pre_trained_model or key.startswith('vgg') or key.startswith('disc'):
                    continue
                print(key)
                val.data.zero()  # fill with zero

        else:
            network = get_tecogan_model(conf, train_lr, train_hr, scope_name)

        # Define discriminator optimizer/solver
        solver_disc = S.Adam(alpha=learning_rate,
                             beta1=conf.train.beta, eps=conf.train.adameps)
        # Set discriminator Parameters
        with nn.parameter_scope("discriminator"):
            solver_disc.set_parameters(nn.get_parameters())

        # setting up monitors for TecoGAN
        tecogan_monitor = get_tecogan_monitors(monitor)

    else:
        network = get_frvsr_model(conf, train_lr, train_hr, scope_name)

    # Define generator and fnet optimizer/solver
    solver_gen = S.Adam(alpha=learning_rate,
                        beta1=conf.train.beta, eps=conf.train.adameps)
    solver_fnet = S.Adam(alpha=learning_rate,
                         beta1=conf.train.beta, eps=conf.train.adameps)

    # Set generator and fnet Parameters
    with nn.parameter_scope(scope_name + "generator"):
        solver_gen.set_parameters(nn.get_parameters())
    with nn.parameter_scope(scope_name + "fnet"):
        solver_fnet.set_parameters(nn.get_parameters())

    if conf.train.tecogan:
        solver_dict = {"gen": solver_gen,
                       "fnet": solver_fnet, "disc": solver_disc}
    else:
        solver_dict = {"gen": solver_gen, "fnet": solver_fnet}

    start_point = 0
    if conf.train.checkpoint:
        # Load optimizer/solver information and model weights from checkpoint
        start_point = load_checkpoint(conf.train.checkpoint, solver_dict)

    # Exponential Moving Average Calculation for tb
    ema = ExponentialMovingAverage(conf.train.decay)
    tb = 0

    # Create output directory if it doesn't exist
    if not os.path.exists(conf.data.output_dir):
        os.makedirs(conf.data.output_dir)

    # Training loop.
    for i in range(start_point, max_iter):
        # Get Training Data
        data_hr.d, train_hr.d = data_iterator_train.next()

        if conf.train.tecogan:
            network.t_discrim_loss.forward(clear_no_need_grad=True)
            if np.less(tb, 0.4):  # update the discriminator only while the balance term tb is small
                # Compute grads for discriminator and update
                solver_disc.zero_grad()
                # Stop back-propagation from t_discrim_loss to generator
                network.t_gen_output.need_grad = False
                if comm.n_procs > 1:
                    all_reduce_callback = comm.get_all_reduce_callback()
                    network.t_discrim_loss.backward(clear_buffer=True,
                                                    communicator_callbacks=all_reduce_callback)
                else:
                    network.t_discrim_loss.backward(clear_buffer=True)
                solver_disc.update()  # Update grads
                # Enable back propagation from fnet_loss to Generator
                network.t_gen_output.need_grad = True

        # Compute grads for fnet and generator together using fnet_loss
        solver_fnet.zero_grad()
        solver_gen.zero_grad()
        # Apply forward and backward propagation on fnet_loss
        network.fnet_loss.forward(clear_no_need_grad=True)
        if comm.n_procs > 1:
            all_reduce_callback = comm.get_all_reduce_callback()
            network.fnet_loss.backward(clear_buffer=True,
                                       communicator_callbacks=all_reduce_callback)
        else:
            network.fnet_loss.backward(clear_buffer=True)
        # Update grads for fnet and generator
        solver_gen.update()
        solver_fnet.update()

        if conf.train.tecogan:
            if comm.n_procs > 1:
                comm.all_reduce([network.t_discrim_real_loss.data,
                                 network.t_adversarial_loss.data], division=True, inplace=True)
            t_balance = F.mean(network.t_discrim_real_loss.data) + \
                network.t_adversarial_loss.data
            if i == 0:
                ema.register(t_balance)
            else:
                tb = ema(t_balance)
            if comm.rank == 0:
                tecogan_monitor.monitor_pp_loss.add(
                    i, network.pp_loss.d.copy())
                tecogan_monitor.monitor_vgg_loss.add(
                    i, network.vgg_loss.d.copy())
                tecogan_monitor.monitor_sum_layer_loss.add(
                    i, network.sum_layer_loss.d.copy())
                tecogan_monitor.monitor_adv_loss.add(
                    i, network.t_adversarial_loss.d.copy())
                tecogan_monitor.monitor_disc_loss.add(
                    i, network.t_discrim_loss.d.copy())
                tecogan_monitor.monitor_tb.add(i, tb)

        if comm.rank == 0:
            common_monitor.monitor_content_loss.add(
                i, network.content_loss.d.copy())
            common_monitor.monitor_gen_loss.add(i, network.gen_loss.d.copy())
            common_monitor.monitor_warp_loss.add(i, network.warp_loss.d.copy())
            common_monitor.monitor_lr.add(i, learning_rate)
            common_monitor.monitor_time.add(i)
            if (i % conf.train.save_freq) == 0:
                # Save intermediate model parameters
                with nn.parameter_scope(scope_name):
                    nn.save_parameters(os.path.join(
                        conf.data.output_dir, "model_param_%08d.h5" % i))

                # Save intermediate check_points
                save_checkpoint(conf.data.output_dir, i, solver_dict)

    # save final Generator and Fnet network parameters
    if comm.rank == 0:
        with nn.parameter_scope(scope_name):
            nn.save_parameters(os.path.join(
                conf.data.output_dir, "model_param_%08d.h5" % i))
Example #7
def train():
    """
    Main script for training.
    """
    args, train_config = get_train_args()

    num_classes = args.num_classes

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"  # TODO: Hard coded!!!
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    # To utilize TensorCore in FP16
    channels = 4 if args.type_config == 'half' else 3

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Create data iterators
    data, vdata = get_data_iterators(args, comm, channels, args.spatial_size)

    # Create mixup object
    mixup = create_mixup_or_none(train_config.mixup, num_classes, comm)

    # Load model for fine-tuning
    if args.finetune:
        assert args.model_load_path is not None, "`--model-load-path` must be set in finetuning mode."
        if comm.rank == 0:
            logger.info(f'Loading parameter file `{args.model_load_path}.`')
            logger.info(
                "NOTE: It doesn't verify the compatibility between the parameter file and the architecture you choose."
            )
        nn.load_parameters(args.model_load_path)
        # Strong assumption: the last two parameters are the weights and bias
        # of the final classification (affine) layer.
        param_keys = list(nn.get_parameters().keys())
        bkey = param_keys[-1]
        wkey = param_keys[-2]
        if comm.rank == 0:
            logger.info(
                f'Removing the last two parameters for fine-tuning, under the assumption that they correspond to the final affine layer parameters: `{wkey}` and `{bkey}`.'
            )
        nn.parameter.pop_parameter(wkey)
        nn.parameter.pop_parameter(bkey)

    # Network for training
    t_model = get_model(args,
                        num_classes,
                        test=False,
                        channel_last=args.channel_last,
                        mixup=mixup,
                        channels=channels,
                        spatial_size=args.spatial_size,
                        label_smoothing=train_config.label_smoothing,
                        ctx_for_loss=comm.ctx_float)

    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=args.channel_last,
                        spatial_size=args.spatial_size,
                        channels=channels)

    # Solver
    # lr will be set later
    solver = MomentumNoWeightDecayBn(1, train_config.momentum)
    solver.set_parameters(nn.get_parameters())

    # Learning rate scheduler
    learning_rate_scheduler = create_learning_rate_scheduler(train_config)

    # Monitors
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)
        save_args(args, train_config)

    # Epoch runner
    loss_scaling = train_config.loss_scaling if args.type_config == 'half' else 1
    train_epoch = EpochTrainer(t_model, solver, learning_rate_scheduler, data,
                               comm, monitor, loss_scaling,
                               train_config.weight_decay, stream_event_handler,
                               mixup)
    val_epoch = None
    if args.val_interval > 0:
        val_epoch = EpochValidator(v_model, vdata, comm, monitor,
                                   stream_event_handler)

    # Epoch loop
    for epoch in range(train_config.epochs):
        # Save parameters
        if epoch > 0 and epoch % (
                args.model_save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(args.monitor_path, 'param_%03d.h5' % epoch))

        # Run validation for examples in an epoch
        if val_epoch is not None \
           and epoch > 0 \
           and epoch % args.val_interval == 0:
            val_epoch.run(epoch)

        # Run training for examples in an epoch
        train_epoch.run(epoch)

    # Run final validation
    if val_epoch is not None:
        val_epoch.run(train_config.epochs)

    # Save the final model.
    if comm.rank == 0:
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         'param_%03d.h5' % (train_config.epochs)))
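
The fine-tuning branch above drops the last two parameters by position, which relies on dictionary order. An alternative, sketched below, removes the classification head by key name instead; the scope name 'fc' is an assumption for illustration and is not taken from the example.

import nnabla as nn


def drop_head_params(head_scope='fc'):
    # Remove every parameter under the (assumed) classifier scope so that the
    # new head is re-initialized when the fine-tuning graph is built.
    for key in list(nn.get_parameters().keys()):
        if key.startswith(head_scope + '/') or ('/' + head_scope + '/') in key:
            nn.parameter.pop_parameter(key)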
Example #8
def train():
    """
    Main script for training.
    """

    args = get_args()

    num_classes = 1000

    # Communicator and Context
    from nnabla.ext_utils import get_extension_context
    extension_module = "cudnn"  # TODO: Hard coded!!!
    ctx = get_extension_context(extension_module,
                                device_id=args.device_id,
                                type_config=args.type_config)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)

    from nnabla_ext.cuda import StreamEventHandler
    stream_event_handler = StreamEventHandler(int(comm.ctx.device_id))

    # Create data iterators
    data, vdata = get_data_iterators(args, comm, stream_event_handler)

    # Network for training
    t_model = get_model(args,
                        num_classes,
                        test=False,
                        channel_last=args.channel_last)

    # Network for validation
    v_model = get_model(args,
                        num_classes,
                        test=True,
                        channel_last=args.channel_last)

    # Solver
    loss_scaling = args.loss_scaling if args.type_config == 'half' else 1
    # To cancel loss scaling, the learning rate is divided by loss_scaling.
    # Note this assumes the legacy SGD w/ momentum implementation;
    # otherwise it is recommended to apply the division to the gradients
    # themselves, e.g. using scale_grad.
    base_learning_rate = args.learning_rate / loss_scaling

    # Weight decay is multiplied by loss_scaling to compensate for the
    # division of the learning rate by loss_scaling above.
    # It is also multiplied by the number of GPUs (processes), because an
    # all-reduce sum over GPUs is performed before weight decay is applied.
    weight_decay = args.weight_decay * loss_scaling * comm.n_procs
    solver = MomentumNoWeightDecayBn(base_learning_rate, 0.9)
    solver.set_parameters(nn.get_parameters())

    # Learning rate scheduler
    decay_rate = 0.1
    learning_rate_scheduler = LearningRateScheduler(
        base_learning_rate, args.learning_rate_decay_at, decay_rate,
        args.warmup_epochs)

    # Monitors
    monitor = None
    if comm.rank == 0:
        if not os.path.isdir(args.monitor_path):
            os.makedirs(args.monitor_path)
        monitor = M.Monitor(args.monitor_path)

    # Epoch runner
    train_epoch = EpochTrainer(t_model, solver, learning_rate_scheduler, data,
                               comm, monitor, loss_scaling, weight_decay,
                               stream_event_handler)
    val_epoch = None
    if args.val_interval > 0:
        val_epoch = EpochValidator(v_model, vdata, comm, monitor,
                                   stream_event_handler)

    # Epoch loop
    for epoch in range(args.max_epochs):
        # Save parameters
        if epoch > 0 and epoch % (
                args.model_save_interval) == 0 and comm.rank == 0:
            nn.save_parameters(
                os.path.join(args.monitor_path, 'param_%03d.h5' % epoch))

        # Run validation for examples in an epoch
        if val_epoch is not None \
           and epoch > 0 \
           and epoch % args.val_interval == 0:
            val_epoch.run(epoch)

        # Run training for examples in an epoch
        train_epoch.run(epoch)

    # Run final validation
    if val_epoch is not None:
        val_epoch.run(args.max_epochs)

    # Save the final model.
    if comm.rank == 0:
        nn.save_parameters(
            os.path.join(args.monitor_path,
                         'param_%03d.h5' % (args.max_epochs)))
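
The comments in this last example explain why the learning rate is divided by loss_scaling while weight decay is multiplied by it. A small worked check of that bookkeeping (hypothetical numbers, plain Python, independent of nnabla) is:

S = 8.0      # loss scaling factor
lr = 0.1     # intended learning rate
grad = 0.02  # true gradient of the unscaled loss
w, w_d = 0.5, 1e-4  # a parameter value and the intended weight decay

scaled_grad = S * grad                # backward() on the scaled loss scales every gradient by S
step = (lr / S) * scaled_grad         # the solver's update with the divided learning rate
assert abs(step - lr * grad) < 1e-12  # identical to the unscaled SGD step

# Weight decay is added to the (scaled) gradient inside the solver, so it is
# multiplied by S (and by the number of GPUs when gradients are all-reduce
# summed) to stay proportional to the gradients it is combined with.
decayed_step = (lr / S) * (scaled_grad + (w_d * S) * w)
assert abs(decayed_step - lr * (grad + w_d * w)) < 1e-12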