Example 1
# Assumed imports for this excerpt (module paths follow nnabla's usual layout);
# _get_generator and _get_matching_variable_names are module-local helpers
# defined alongside this function.
from collections import OrderedDict

import nnabla as nn
import nnabla.solvers as S
from nnabla.utils.communicator_util import current_communicator
from nnabla.utils.learning_rate_scheduler import (
    CosineScheduler, ExponentialScheduler, LinearWarmupScheduler,
    PolynomialScheduler, StepScheduler)


def _create_optimizer(ctx, o, networks, datasets):
    class Optimizer:
        pass

    optimizer = Optimizer()

    optimizer.comm = current_communicator()
    comm_size = optimizer.comm.size if optimizer.comm else 1
    optimizer.start_iter = (
        (o.start_iter - 1) // comm_size + 1) if o.start_iter > 0 else 0
    optimizer.end_iter = (
        (o.end_iter - 1) // comm_size + 1) if o.end_iter > 0 else 0
    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterators = OrderedDict()
    for d in o.dataset_name:
        optimizer.data_iterators[d] = datasets[d].data_iterator

    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[optimizer.network.variables[
            d.variable_name]] = d.data_name

    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[optimizer.network.variables[
            g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = []
    for l in o.loss_variable:
        optimizer.loss_variables.append(
            optimizer.network.variables[l.variable_name])

    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        param_variable_names = _get_matching_variable_names(
            p.variable_name, optimizer.network.variables.keys())
        for v_name in param_variable_names:
            optimizer.parameter_learning_rate_multipliers[
                optimizer.network.variables[v_name]] = p.learning_rate_multiplier

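    # Instantiate the solver named in the protobuf configuration; init_lr
    # records the base learning rate that the schedulers below start from.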
    with nn.context_scope(ctx):
        if o.solver.type == 'Adagrad':
            optimizer.solver = S.Adagrad(o.solver.adagrad_param.lr,
                                         o.solver.adagrad_param.eps)
            init_lr = o.solver.adagrad_param.lr
        elif o.solver.type == 'Adadelta':
            optimizer.solver = S.Adadelta(o.solver.adadelta_param.lr,
                                          o.solver.adadelta_param.decay,
                                          o.solver.adadelta_param.eps)
            init_lr = o.solver.adadelta_param.lr
        elif o.solver.type == 'Adam':
            optimizer.solver = S.Adam(o.solver.adam_param.alpha,
                                      o.solver.adam_param.beta1,
                                      o.solver.adam_param.beta2,
                                      o.solver.adam_param.eps)
            init_lr = o.solver.adam_param.alpha
        elif o.solver.type == 'Adamax':
            optimizer.solver = S.Adamax(o.solver.adamax_param.alpha,
                                        o.solver.adamax_param.beta1,
                                        o.solver.adamax_param.beta2,
                                        o.solver.adamax_param.eps)
            init_lr = o.solver.adamax_param.alpha
        elif o.solver.type == 'AdaBound':
            optimizer.solver = S.AdaBound(o.solver.adabound_param.alpha,
                                          o.solver.adabound_param.beta1,
                                          o.solver.adabound_param.beta2,
                                          o.solver.adabound_param.eps,
                                          o.solver.adabound_param.final_lr,
                                          o.solver.adabound_param.gamma)
            init_lr = o.solver.adabound_param.alpha
        elif o.solver.type == 'AMSGRAD':
            optimizer.solver = S.AMSGRAD(o.solver.amsgrad_param.alpha,
                                         o.solver.amsgrad_param.beta1,
                                         o.solver.amsgrad_param.beta2,
                                         o.solver.amsgrad_param.eps)
            init_lr = o.solver.amsgrad_param.alpha
        elif o.solver.type == 'AMSBound':
            optimizer.solver = S.AMSBound(o.solver.amsbound_param.alpha,
                                          o.solver.amsbound_param.beta1,
                                          o.solver.amsbound_param.beta2,
                                          o.solver.amsbound_param.eps,
                                          o.solver.amsbound_param.final_lr,
                                          o.solver.amsbound_param.gamma)
            init_lr = o.solver.amsbound_param.alpha
        elif o.solver.type == 'Eve':
            p = o.solver.eve_param
            optimizer.solver = S.Eve(p.alpha, p.beta1, p.beta2, p.beta3, p.k,
                                     p.k2, p.eps)
            init_lr = p.alpha
        elif o.solver.type == 'Momentum':
            optimizer.solver = S.Momentum(o.solver.momentum_param.lr,
                                          o.solver.momentum_param.momentum)
            init_lr = o.solver.momentum_param.lr
        elif o.solver.type == 'Nesterov':
            optimizer.solver = S.Nesterov(o.solver.nesterov_param.lr,
                                          o.solver.nesterov_param.momentum)
            init_lr = o.solver.nesterov_param.lr
        elif o.solver.type == 'RMSprop':
            optimizer.solver = S.RMSprop(o.solver.rmsprop_param.lr,
                                         o.solver.rmsprop_param.decay,
                                         o.solver.rmsprop_param.eps)
            init_lr = o.solver.rmsprop_param.lr
        elif o.solver.type == 'Sgd' or o.solver.type == 'SGD':
            optimizer.solver = S.Sgd(o.solver.sgd_param.lr)
            init_lr = o.solver.sgd_param.lr
        else:
            raise ValueError('Solver "' + o.solver.type +
                             '" is not supported.')

    parameters = {
        v.name: v.variable_instance
        for v, local_lr in
        optimizer.parameter_learning_rate_multipliers.items() if local_lr > 0.0
    }
    optimizer.solver.set_parameters(parameters)
    optimizer.parameters = OrderedDict(
        sorted(parameters.items(), key=lambda x: x[0]))

    optimizer.weight_decay = o.solver.weight_decay

    # Keep the following two lines for backward compatibility.
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1
    optimizer.solver.set_states_from_protobuf(o)

    optimizer.scheduler = ExponentialScheduler(init_lr, 1.0, 1)

    if o.solver.lr_scheduler_type == 'Polynomial':
        if o.solver.polynomial_scheduler_param.power != 0.0:
            optimizer.scheduler = PolynomialScheduler(
                init_lr,
                o.solver.polynomial_scheduler_param.max_iter // comm_size,
                o.solver.polynomial_scheduler_param.power)
    elif o.solver.lr_scheduler_type == 'Cosine':
        optimizer.scheduler = CosineScheduler(
            init_lr, o.solver.cosine_scheduler_param.max_iter // comm_size)
    elif o.solver.lr_scheduler_type == 'Exponential':
        if o.solver.exponential_scheduler_param.gamma != 1.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr, o.solver.exponential_scheduler_param.gamma,
                o.solver.exponential_scheduler_param.iter_interval //
                comm_size if
                o.solver.exponential_scheduler_param.iter_interval > comm_size
                else 1)
    elif o.solver.lr_scheduler_type == 'Step':
        if o.solver.step_scheduler_param.gamma != 1.0 and len(
                o.solver.step_scheduler_param.iter_steps) > 0:
            optimizer.scheduler = StepScheduler(
                init_lr, o.solver.step_scheduler_param.gamma, [
                    step // comm_size
                    for step in o.solver.step_scheduler_param.iter_steps
                ])
    elif o.solver.lr_scheduler_type == 'Custom':
        # ToDo
        raise NotImplementedError()
    elif o.solver.lr_scheduler_type == '':
        if o.solver.lr_decay_interval != 0 or o.solver.lr_decay != 0.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0,
                o.solver.lr_decay_interval //
                comm_size if o.solver.lr_decay_interval > comm_size else 1)
    else:
        raise ValueError('Learning Rate Scheduler "' +
                         o.solver.lr_scheduler_type + '" is not supported.')

    if o.solver.lr_warmup_scheduler_type == 'Linear':
        if o.solver.linear_warmup_scheduler_param.warmup_iter >= comm_size:
            optimizer.scheduler = LinearWarmupScheduler(
                optimizer.scheduler,
                o.solver.linear_warmup_scheduler_param.warmup_iter //
                comm_size)

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables,
        optimizer.parameter_learning_rate_multipliers)

    return optimizer
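
The object returned above bundles a solver, a learning-rate scheduler and the
forward/backward sequences for the loss variables. The following sketch is not
part of the source; it only illustrates how one update step could be driven
with that object, assuming the standard nnabla solver methods (zero_grad,
set_learning_rate, weight_decay, update) and a scheduler that exposes
get_learning_rate(iter).

def update_once(optimizer, loss, iteration):
    # Hypothetical helper: one parameter update with the object built by
    # _create_optimizer.
    optimizer.solver.zero_grad()        # clear accumulated gradients
    loss.forward()                      # forward pass on the loss variable
    loss.backward()                     # backprop into parameter gradients
    if optimizer.weight_decay > 0.0:
        optimizer.solver.weight_decay(optimizer.weight_decay)
    # Assumption: the scheduler provides get_learning_rate(iteration).
    optimizer.solver.set_learning_rate(
        optimizer.scheduler.get_learning_rate(iteration))
    optimizer.solver.update()           # apply the gradient step
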
Example 2
def train():
    """
    Main script.

    Steps:

    * Parse command line arguments.
    * Specify a context for computation.
    * Initialize DataIterator for CIFAR10.
    * Construct a computation graph for training and validation.
    * Initialize a solver and set parameter variables to it.
    * Training loop
      * Compute error rate for validation data (periodically).
      * Get the next minibatch.
      * Execute forwardprop on the training graph.
      * Compute training error.
      * Set parameter gradients to zero.
      * Execute backprop.
      * Solver updates parameters by using gradients computed by backprop.
    """

    # define training parameters
    augmented_shift = True
    augmented_flip = True
    batch_size = 128
    vbatch_size = 100
    num_classes = 10
    weight_decay = 0.0002
    momentum = 0.9
    learning_rates = (cfg.initial_learning_rate,)*80 + \
        (cfg.initial_learning_rate / 10.,)*40 + \
        (cfg.initial_learning_rate / 100.,)*40
    print('lr={}'.format(learning_rates))
    print('weight_decay={}'.format(weight_decay))
    print('momentum={}'.format(momentum))

    # create nnabla context
    from nnabla.ext_utils import get_extension_context
    ctx = get_extension_context('cudnn', device_id=args.gpu)
    nn.set_default_context(ctx)

    # Initialize DataIterator for CIFAR10.
    logger.info("Get CIFAR10 Data ...")
    data = cifar_data.DataIterator(batch_size,
                                   augmented_shift=augmented_shift,
                                   augmented_flip=augmented_flip)
    vdata = cifar_data.DataIterator(vbatch_size, val=True)

    if cfg.weightfile is not None:
        logger.info(f"Loading weights from {cfg.weightfile}")
        nn.load_parameters(cfg.weightfile)

    # TRAIN
    # Create input variables.
    image = nn.Variable([batch_size, 3, 32, 32])
    label = nn.Variable([batch_size, 1])

    # Create prediction graph.
    pred, hidden = resnet_cifar10(image,
                                  num_classes=num_classes,
                                  cfg=cfg,
                                  test=False)
    pred.persistent = True

    # Compute initial network size
    num_weights, kbytes_weights = network_size_weights()
    kbytes_weights.forward()
    print(f"Initial network size (weights) is {float(kbytes_weights.d):.3f}KB "
          f"(total number of weights: {int(num_weights):d}).")

    num_activations, kbytes_activations = network_size_activations()
    kbytes_activations.forward()
    print(
        f"Initial network size (activations) is {float(kbytes_activations.d):.3f}KB "
        f"(total number of activations: {int(num_activations):d}).")

    # Create loss function.
    cost_lambda2 = nn.Variable(())
    cost_lambda2.d = cfg.initial_cost_lambda2
    cost_lambda2.persistent = True
    cost_lambda3 = nn.Variable(())
    cost_lambda3.d = cfg.initial_cost_lambda3
    cost_lambda3.persistent = True

    loss1 = F.mean(F.softmax_cross_entropy(pred, label))
    loss1.persistent = True

    if cfg.target_weight_kbytes > 0:
        loss2 = F.relu(kbytes_weights - cfg.target_weight_kbytes)**2
        loss2.persistent = True
    else:
        loss2 = nn.Variable(())
        loss2.d = 0
        loss2.persistent = True
    if cfg.target_activation_kbytes > 0:
        loss3 = F.relu(kbytes_activations - cfg.target_activation_kbytes)**2
        loss3.persistent = True
    else:
        loss3 = nn.Variable(())
        loss3.d = 0
        loss3.persistent = True

    loss = loss1 + cost_lambda2 * loss2 + cost_lambda3 * loss3

    # VALID
    # Create input variables.
    vimage = nn.Variable([vbatch_size, 3, 32, 32])
    vlabel = nn.Variable([vbatch_size, 1])
    # Create prediction graph.
    vpred, vhidden = resnet_cifar10(vimage,
                                    num_classes=num_classes,
                                    cfg=cfg,
                                    test=True)
    vpred.persistent = True

    # Create Solver.
    if cfg.optimizer == "adam":
        solver = S.Adam(alpha=learning_rates[0])
    else:
        solver = S.Momentum(learning_rates[0], momentum)

    solver.set_parameters(nn.get_parameters())

    # Training loop (epochs)
    logger.info("Start Training ...")
    i = 0
    best_v_err = 1.0

    # logs of the results
    iters = []
    res_train_err = []
    res_train_loss = []
    res_val_err = []

    # print all variables that exist
    for k in nn.get_parameters():
        print(k)

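    # Per-epoch logs for the learned quantizer parameters (n, d, xmin, xmax),
    # kept separately for the bquant / Wquant / Aquant parameter groups.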
    res_n_b = collections.OrderedDict()
    res_n_w = collections.OrderedDict()
    res_n_a = collections.OrderedDict()
    res_d_b = collections.OrderedDict()
    res_d_w = collections.OrderedDict()
    res_d_a = collections.OrderedDict()
    res_xmin_b = collections.OrderedDict()
    res_xmin_w = collections.OrderedDict()
    res_xmin_a = collections.OrderedDict()
    res_xmax_b = collections.OrderedDict()
    res_xmax_w = collections.OrderedDict()
    res_xmax_a = collections.OrderedDict()

    # Register a log list for every quantizer parameter, keyed by its name
    # (last path component) and its quantizer group (third-to-last component).
    res_by_key = {
        ('n', 'bquant'): res_n_b, ('n', 'Wquant'): res_n_w,
        ('n', 'Aquant'): res_n_a, ('d', 'bquant'): res_d_b,
        ('d', 'Wquant'): res_d_w, ('d', 'Aquant'): res_d_a,
        ('xmin', 'bquant'): res_xmin_b, ('xmin', 'Wquant'): res_xmin_w,
        ('xmin', 'Aquant'): res_xmin_a, ('xmax', 'bquant'): res_xmax_b,
        ('xmax', 'Wquant'): res_xmax_w, ('xmax', 'Aquant'): res_xmax_a}
    for k in nn.get_parameters():
        parts = k.split('/')
        if len(parts) >= 3 and (parts[-1], parts[-3]) in res_by_key:
            res_by_key[(parts[-1], parts[-3])][k] = []

    for epoch in range(len(learning_rates)):
        train_loss = list()
        train_loss1 = list()
        train_loss2 = list()
        train_loss3 = list()
        train_err = list()

        # check whether we need to adapt the learning rate
        if epoch > 0 and learning_rates[epoch - 1] != learning_rates[epoch]:
            solver.set_learning_rate(learning_rates[epoch])

        # Training loop (iterations)
        start_epoch = True
        while data.current != 0 or start_epoch:
            start_epoch = False
            # Next batch
            image.d, label.d = data.next()

            # Training forward/backward
            solver.zero_grad()

            loss.forward()
            loss.backward()

            if weight_decay is not None:
                solver.weight_decay(weight_decay)

            # scale gradients
            if cfg.target_weight_kbytes > 0 or cfg.target_activation_kbytes > 0:
                clip_quant_grads()

            solver.update()
            e = categorical_error(pred.d, label.d)
            train_loss += [loss.d]
            train_loss1 += [loss1.d]
            train_loss2 += [loss2.d]
            train_loss3 += [loss3.d]
            train_err += [e]

            # make sure the parametric values are clipped back into their valid range (if outside)
            clip_quant_vals()

            # Intermediate Validation (when constraint is set and fulfilled)
            kbytes_weights.forward()
            kbytes_activations.forward()
            if ((cfg.target_weight_kbytes > 0
                 and (cfg.target_weight_kbytes <= 0
                      or float(kbytes_weights.d) <= cfg.target_weight_kbytes)
                 and (cfg.target_activation_kbytes <= 0 or float(
                     kbytes_activations.d) <= cfg.target_activation_kbytes))):

                ve = list()
                start_epoch_ = True
                while vdata.current != 0 or start_epoch_:
                    start_epoch_ = False
                    vimage.d, vlabel.d = vdata.next()
                    vpred.forward()
                    ve += [categorical_error(vpred.d, vlabel.d)]

                v_err = np.array(ve).mean()
                if v_err < best_v_err:
                    best_v_err = v_err
                    nn.save_parameters(
                        os.path.join(cfg.params_dir, 'params_best.h5'))
                    print(
                        f'Best validation error (fulfilling constraints): {best_v_err}'
                    )
                    sys.stdout.flush()
                    sys.stderr.flush()

            i += 1

        # Validation
        ve = list()
        start_epoch = True
        while vdata.current != 0 or start_epoch:
            start_epoch = False
            vimage.d, vlabel.d = vdata.next()
            vpred.forward()
            ve += [categorical_error(vpred.d, vlabel.d)]

        v_err = np.array(ve).mean()
        kbytes_weights.forward()
        kbytes_activations.forward()
        if ((v_err < best_v_err
             and (cfg.target_weight_kbytes <= 0
                  or float(kbytes_weights.d) <= cfg.target_weight_kbytes) and
             (cfg.target_activation_kbytes <= 0 or
              float(kbytes_activations.d) <= cfg.target_activation_kbytes))):
            best_v_err = v_err
            nn.save_parameters(os.path.join(cfg.params_dir, 'params_best.h5'))
            sys.stdout.flush()
            sys.stderr.flush()

        if cfg.target_weight_kbytes > 0:
            print(
                f"Current network size (weights) is {float(kbytes_weights.d):.3f}KB "
                f"(#params: {int(num_weights)}, "
                f"avg. bitwidth: {8. * 1024. * kbytes_weights.d / num_weights})"
            )
            sys.stdout.flush()
            sys.stderr.flush()
        if cfg.target_activation_kbytes > 0:
            print(
                f"Current network size (activations) is {float(kbytes_activations.d):.3f}KB"
            )
            sys.stdout.flush()
            sys.stderr.flush()

        # Log the current value and gradient of every quantizer parameter and
        # append the value to its per-epoch log.
        for k in nn.get_parameters():
            parts = k.split('/')
            if parts[-1] in ('n', 'd', 'xmin', 'xmax'):
                print(f'{k}', f'{nn.get_parameters()[k].d}',
                      f'{nn.get_parameters()[k].g}')
                sys.stdout.flush()
                sys.stderr.flush()
                if len(parts) >= 3 and (parts[-1], parts[-3]) in res_by_key:
                    res_by_key[(parts[-1], parts[-3])][k].append(
                        nn.get_parameters()[k].d.item())

        # Print
        logger.info(f'epoch={epoch}(iter={i}); '
                    f'overall cost={np.array(train_loss).mean()}; '
                    f'cross-entropy cost={np.array(train_loss1).mean()}; '
                    f'weight-size cost={np.array(train_loss2).mean()}; '
                    f'activations-size cost={np.array(train_loss3).mean()}; '
                    f'TrainErr={np.array(train_err).mean()}; '
                    f'ValidErr={v_err}; BestValidErr={best_v_err}')
        sys.stdout.flush()
        sys.stderr.flush()

        # update the logs
        iters.append(i)
        res_train_err.append(np.array(train_err).mean())
        res_train_loss.append([
            np.array(train_loss).mean(),
            np.array(train_loss1).mean(),
            np.array(train_loss2).mean(),
            np.array(train_loss3).mean()
        ])
        res_val_err.append(np.array(v_err).mean())
        res_ges = np.concatenate([
            np.array(iters)[:, np.newaxis],
            np.array(res_train_err)[:, np.newaxis],
            np.array(res_val_err)[:, np.newaxis],
            np.array(res_train_loss)
        ], axis=-1)

        # save the results
        np.savetxt(cfg.params_dir + '/results.csv',
                   np.array(res_ges),
                   fmt='%10.8f',
                   header='iter,train_err,val_err,loss,loss1,loss2,loss3',
                   comments='',
                   delimiter=',')

        for rs, res in zip([
                'res_n_b.csv', 'res_n_w.csv', 'res_n_a.csv', 'res_d_b.csv',
                'res_d_w.csv', 'res_d_a.csv', 'res_min_b.csv', 'res_min_w.csv',
                'res_min_a.csv', 'res_max_b.csv', 'res_max_w.csv',
                'res_max_a.csv'
        ], [
                res_n_b, res_n_w, res_n_a, res_d_b, res_d_w, res_d_a,
                res_xmin_b, res_xmin_w, res_xmin_a, res_xmax_b, res_xmax_w,
                res_xmax_a
        ]):
            res_mat = np.array([res[i] for i in res])
            if res_mat.shape[0] > 1 and res_mat.shape[1] > 1:
                np.savetxt(
                    cfg.params_dir + '/' + rs,
                    np.array([[i, j, res_mat[i, j]] for i, j in product(
                        range(res_mat.shape[0]), range(res_mat.shape[1]))]),
                    fmt='%10.8f',
                    comments='',
                    delimiter=',')
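
The loss optimized above combines the cross-entropy term with squared hinge
penalties on the network size: each penalty is zero while the corresponding
size stays under its target and grows quadratically once the budget is
exceeded. The following sketch is not part of the source; size_penalty is a
hypothetical helper name that simply mirrors the loss2/loss3 construction.

import nnabla as nn
import nnabla.functions as F

def size_penalty(kbytes_var, target_kbytes):
    # Squared hinge on the size budget: zero below target, quadratic above.
    if target_kbytes > 0:
        return F.relu(kbytes_var - target_kbytes) ** 2
    zero = nn.Variable(())   # no budget set: contribute a constant zero
    zero.d = 0
    return zero

In the training loop these penalties are weighted by cost_lambda2 and
cost_lambda3 and added to the cross-entropy loss.
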
Example 3
# Assumed imports for this excerpt (module paths follow nnabla's usual layout);
# _get_generator is a module-local helper defined alongside this function.
from collections import OrderedDict

import nnabla as nn
import nnabla.solvers as S


def _create_optimizer(ctx, o, networks, datasets):
    class Optimizer:
        pass

    optimizer = Optimizer()

    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterator = datasets[o.dataset_name].data_iterator

    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[
            optimizer.network.variables[d.variable_name]] = d.data_name

    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[optimizer.network.variables[
            g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = []
    for l in o.loss_variable:
        optimizer.loss_variables.append(
            optimizer.network.variables[l.variable_name])

    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        param_variable_names = [
            v_name for v_name in optimizer.network.variables.keys()
            if v_name.startswith(p.variable_name)]
        for v_name in param_variable_names:
            optimizer.parameter_learning_rate_multipliers[
                optimizer.network.variables[v_name]] = p.learning_rate_multiplier

    with nn.context_scope(ctx):
        if o.solver.type == 'Adagrad':
            optimizer.solver = S.Adagrad(
                o.solver.adagrad_param.lr, o.solver.adagrad_param.eps)
        elif o.solver.type == 'Adadelta':
            optimizer.solver = S.Adadelta(
                o.solver.adadelta_param.lr, o.solver.adadelta_param.decay, o.solver.adadelta_param.eps)
        elif o.solver.type == 'Adam':
            optimizer.solver = S.Adam(o.solver.adam_param.alpha, o.solver.adam_param.beta1,
                                      o.solver.adam_param.beta2, o.solver.adam_param.eps)
        elif o.solver.type == 'Adamax':
            optimizer.solver = S.Adamax(o.solver.adamax_param.alpha, o.solver.adamax_param.beta1,
                                        o.solver.adamax_param.beta2, o.solver.adamax_param.eps)
        elif o.solver.type == 'Eve':
            p = o.solver.eve_param
            optimizer.solver = S.Eve(
                p.alpha, p.beta1, p.beta2, p.beta3, p.k, p.k2, p.eps)
        elif o.solver.type == 'Momentum':
            optimizer.solver = S.Momentum(
                o.solver.momentum_param.lr, o.solver.momentum_param.momentum)
        elif o.solver.type == 'Nesterov':
            optimizer.solver = S.Nesterov(
                o.solver.nesterov_param.lr, o.solver.nesterov_param.momentum)
        elif o.solver.type == 'RMSprop':
            optimizer.solver = S.RMSprop(
                o.solver.rmsprop_param.lr, o.solver.rmsprop_param.decay, o.solver.rmsprop_param.eps)
        elif o.solver.type == 'Sgd' or o.solver.type == 'SGD':
            optimizer.solver = S.Sgd(o.solver.sgd_param.lr)
        else:
            raise ValueError('Solver "' + o.solver.type +
                             '" is not supported.')

    optimizer.solver.set_parameters({
        v.name: v.variable_instance
        for v, local_lr in optimizer.parameter_learning_rate_multipliers.items()
        if local_lr > 0.0})

    optimizer.weight_decay = o.solver.weight_decay
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables, optimizer.parameter_learning_rate_multipliers)

    return optimizer
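
Note that, in both versions of _create_optimizer, only parameters whose
learning_rate_multiplier is greater than zero are handed to the solver, so a
multiplier of 0.0 effectively freezes a parameter at its initial value. As a
hypothetical follow-up on the returned object:

# List the parameters the optimizer will never update (multiplier <= 0.0),
# using the optimizer object returned by _create_optimizer above.
frozen = sorted(
    v.name for v, mult in optimizer.parameter_learning_rate_multipliers.items()
    if mult <= 0.0)
print('Frozen parameters:', frozen)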