Example #1
def _create_optimizer(ctx, o, networks, datasets):
    class Optimizer:
        pass

    optimizer = Optimizer()

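    # When training runs under a multi-process communicator, iteration counts
    # from the configuration are rescaled by the number of processes.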
    optimizer.comm = current_communicator()
    comm_size = optimizer.comm.size if optimizer.comm else 1
    optimizer.start_iter = ((o.start_iter - 1) // comm_size + 1
                            if o.start_iter > 0 else 0)
    optimizer.end_iter = ((o.end_iter - 1) // comm_size + 1
                          if o.end_iter > 0 else 0)
    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterators = OrderedDict()
    for d in o.dataset_name:
        optimizer.data_iterators[d] = datasets[d].data_iterator

    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[optimizer.network.variables[
            d.variable_name]] = d.data_name

    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[optimizer.network.variables[
            g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = []
    for l in o.loss_variable:
        optimizer.loss_variables.append(
            optimizer.network.variables[l.variable_name])

    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        param_variable_names = _get_matching_variable_names(
            p.variable_name, optimizer.network.variables.keys())
        for v_name in param_variable_names:
            optimizer.parameter_learning_rate_multipliers[
                optimizer.network.variables[v_name]] = p.learning_rate_multiplier

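    # Instantiate the solver named in the configuration; init_lr is kept so the
    # learning-rate schedulers below can start from the configured rate.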
    with nn.context_scope(ctx):
        if o.solver.type == 'Adagrad':
            optimizer.solver = S.Adagrad(o.solver.adagrad_param.lr,
                                         o.solver.adagrad_param.eps)
            init_lr = o.solver.adagrad_param.lr
        elif o.solver.type == 'Adadelta':
            optimizer.solver = S.Adadelta(o.solver.adadelta_param.lr,
                                          o.solver.adadelta_param.decay,
                                          o.solver.adadelta_param.eps)
            init_lr = o.solver.adadelta_param.lr
        elif o.solver.type == 'Adam':
            optimizer.solver = S.Adam(o.solver.adam_param.alpha,
                                      o.solver.adam_param.beta1,
                                      o.solver.adam_param.beta2,
                                      o.solver.adam_param.eps)
            init_lr = o.solver.adam_param.alpha
        elif o.solver.type == 'Adamax':
            optimizer.solver = S.Adamax(o.solver.adamax_param.alpha,
                                        o.solver.adamax_param.beta1,
                                        o.solver.adamax_param.beta2,
                                        o.solver.adamax_param.eps)
            init_lr = o.solver.adamax_param.alpha
        elif o.solver.type == 'AdaBound':
            optimizer.solver = S.AdaBound(o.solver.adabound_param.alpha,
                                          o.solver.adabound_param.beta1,
                                          o.solver.adabound_param.beta2,
                                          o.solver.adabound_param.eps,
                                          o.solver.adabound_param.final_lr,
                                          o.solver.adabound_param.gamma)
            init_lr = o.solver.adabound_param.alpha
        elif o.solver.type == 'AMSGRAD':
            optimizer.solver = S.AMSGRAD(o.solver.amsgrad_param.alpha,
                                         o.solver.amsgrad_param.beta1,
                                         o.solver.amsgrad_param.beta2,
                                         o.solver.amsgrad_param.eps)
            init_lr = o.solver.amsgrad_param.alpha
        elif o.solver.type == 'AMSBound':
            optimizer.solver = S.AMSBound(o.solver.amsbound_param.alpha,
                                          o.solver.amsbound_param.beta1,
                                          o.solver.amsbound_param.beta2,
                                          o.solver.amsbound_param.eps,
                                          o.solver.amsbound_param.final_lr,
                                          o.solver.amsbound_param.gamma)
            init_lr = o.solver.amsbound_param.alpha
        elif o.solver.type == 'Eve':
            p = o.solver.eve_param
            optimizer.solver = S.Eve(p.alpha, p.beta1, p.beta2, p.beta3, p.k,
                                     p.k2, p.eps)
            init_lr = p.alpha
        elif o.solver.type == 'Momentum':
            optimizer.solver = S.Momentum(o.solver.momentum_param.lr,
                                          o.solver.momentum_param.momentum)
            init_lr = o.solver.momentum_param.lr
        elif o.solver.type == 'Nesterov':
            optimizer.solver = S.Nesterov(o.solver.nesterov_param.lr,
                                          o.solver.nesterov_param.momentum)
            init_lr = o.solver.nesterov_param.lr
        elif o.solver.type == 'RMSprop':
            optimizer.solver = S.RMSprop(o.solver.rmsprop_param.lr,
                                         o.solver.rmsprop_param.decay,
                                         o.solver.rmsprop_param.eps)
            init_lr = o.solver.rmsprop_param.lr
        elif o.solver.type == 'Sgd' or o.solver.type == 'SGD':
            optimizer.solver = S.Sgd(o.solver.sgd_param.lr)
            init_lr = o.solver.sgd_param.lr
        else:
            raise ValueError('Solver "' + o.solver.type +
                             '" is not supported.')

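    # Register only parameters with a positive learning-rate multiplier; the
    # remaining parameters are excluded from updates.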
    parameters = {
        v.name: v.variable_instance
        for v, local_lr in
        optimizer.parameter_learning_rate_multipliers.items() if local_lr > 0.0
    }
    optimizer.solver.set_parameters(parameters)
    optimizer.parameters = OrderedDict(
        sorted(parameters.items(), key=lambda x: x[0]))

    optimizer.weight_decay = o.solver.weight_decay

    # Keep the following two lines for backward compatibility.
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1
    optimizer.solver.set_states_from_protobuf(o)

    # Default scheduler: constant learning rate (exponential decay with gamma 1.0).
    optimizer.scheduler = ExponentialScheduler(init_lr, 1.0, 1)

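    # Replace the default scheduler with the configured one, scaling iteration
    # counts by the communicator size.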
    if o.solver.lr_scheduler_type == 'Polynomial':
        if o.solver.polynomial_scheduler_param.power != 0.0:
            optimizer.scheduler = PolynomialScheduler(
                init_lr,
                o.solver.polynomial_scheduler_param.max_iter // comm_size,
                o.solver.polynomial_scheduler_param.power)
    elif o.solver.lr_scheduler_type == 'Cosine':
        optimizer.scheduler = CosineScheduler(
            init_lr, o.solver.cosine_scheduler_param.max_iter // comm_size)
    elif o.solver.lr_scheduler_type == 'Exponential':
        if o.solver.exponential_scheduler_param.gamma != 1.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr, o.solver.exponential_scheduler_param.gamma,
                o.solver.exponential_scheduler_param.iter_interval //
                comm_size if
                o.solver.exponential_scheduler_param.iter_interval > comm_size
                else 1)
    elif o.solver.lr_scheduler_type == 'Step':
        if o.solver.step_scheduler_param.gamma != 1.0 and len(
                o.solver.step_scheduler_param.iter_steps) > 0:
            optimizer.scheduler = StepScheduler(
                init_lr, o.solver.step_scheduler_param.gamma, [
                    step // comm_size
                    for step in o.solver.step_scheduler_param.iter_steps
                ])
    elif o.solver.lr_scheduler_type == 'Custom':
        # ToDo
        raise NotImplementedError()
    elif o.solver.lr_scheduler_type == '':
        if o.solver.lr_decay_interval != 0 or o.solver.lr_decay != 0.0:
            optimizer.scheduler = ExponentialScheduler(
                init_lr, o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0,
                o.solver.lr_decay_interval //
                comm_size if o.solver.lr_decay_interval > comm_size else 1)
    else:
        raise ValueError('Learning Rate Scheduler "' +
                         o.solver.lr_scheduler_type + '" is not supported.')

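    # Optionally wrap the scheduler with a linear warmup phase.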
    if o.solver.lr_warmup_scheduler_type == 'Linear':
        if o.solver.linear_warmup_scheduler_param.warmup_iter >= comm_size:
            optimizer.scheduler = LinearWarmupScheduler(
                optimizer.scheduler,
                o.solver.linear_warmup_scheduler_param.warmup_iter //
                comm_size)

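    # Precompute the execution order of functions for the forward and backward
    # passes over the loss variables.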
    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables,
        optimizer.parameter_learning_rate_multipliers)

    return optimizer
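
This function only assembles the optimizer object; the training loop elsewhere in the module drives it. Below is a minimal sketch (not part of the original source) of how one update step could use the returned object; update_once and feed_batch are hypothetical names, and the batch-feeding step is left abstract.

def update_once(optimizer, iteration):
    # Learning rate for this iteration comes from the configured scheduler.
    optimizer.solver.set_learning_rate(
        optimizer.scheduler.get_learning_rate(iteration))

    # Hypothetical helper: fill the variables listed in dataset_assign and
    # generator_assign with the current batch.
    feed_batch(optimizer)

    # Run the precomputed forward/backward sequences, then update parameters.
    optimizer.network.forward(optimizer.forward_sequence)
    optimizer.solver.zero_grad()
    optimizer.network.backward(optimizer.backward_sequence)
    if optimizer.weight_decay > 0.0:
        optimizer.solver.weight_decay(optimizer.weight_decay)
    optimizer.solver.update()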
Example #2
def _create_optimizer(ctx, o, networks, datasets):
    class Optimizer:
        pass

    optimizer = Optimizer()

    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterator = datasets[o.dataset_name].data_iterator

    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[
            optimizer.network.variables[d.variable_name]] = d.data_name

    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[optimizer.network.variables[
            g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = []
    for l in o.loss_variable:
        optimizer.loss_variables.append(
            optimizer.network.variables[l.variable_name])

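    # Parameters are selected by a plain prefix match on the variable name
    # (see the illustration after this example).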
    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        param_variable_names = [
            v_name for v_name in optimizer.network.variables.keys()
            if v_name.find(p.variable_name) == 0]
        for v_name in param_variable_names:
            optimizer.parameter_learning_rate_multipliers[
                optimizer.network.variables[v_name]] = p.learning_rate_multiplier

    with nn.context_scope(ctx):
        if o.solver.type == 'Adagrad':
            optimizer.solver = S.Adagrad(
                o.solver.adagrad_param.lr, o.solver.adagrad_param.eps)
        elif o.solver.type == 'Adadelta':
            optimizer.solver = S.Adadelta(
                o.solver.adadelta_param.lr, o.solver.adadelta_param.decay, o.solver.adadelta_param.eps)
        elif o.solver.type == 'Adam':
            optimizer.solver = S.Adam(o.solver.adam_param.alpha, o.solver.adam_param.beta1,
                                      o.solver.adam_param.beta2, o.solver.adam_param.eps)
        elif o.solver.type == 'Adamax':
            optimizer.solver = S.Adamax(o.solver.adamax_param.alpha, o.solver.adamax_param.beta1,
                                        o.solver.adamax_param.beta2, o.solver.adamax_param.eps)
        elif o.solver.type == 'Eve':
            p = o.solver.eve_param
            optimizer.solver = S.Eve(
                p.alpha, p.beta1, p.beta2, p.beta3, p.k, p.k2, p.eps)
        elif o.solver.type == 'Momentum':
            optimizer.solver = S.Momentum(
                o.solver.momentum_param.lr, o.solver.momentum_param.momentum)
        elif o.solver.type == 'Nesterov':
            optimizer.solver = S.Nesterov(
                o.solver.nesterov_param.lr, o.solver.nesterov_param.momentum)
        elif o.solver.type == 'RMSprop':
            optimizer.solver = S.RMSprop(
                o.solver.rmsprop_param.lr, o.solver.rmsprop_param.decay, o.solver.rmsprop_param.eps)
        elif o.solver.type == 'Sgd' or o.solver.type == 'SGD':
            optimizer.solver = S.Sgd(o.solver.sgd_param.lr)
        else:
            raise ValueError('Solver "' + o.solver.type +
                             '" is not supported.')

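    # Register only parameters with a positive learning-rate multiplier.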
    optimizer.solver.set_parameters({
        v.name: v.variable_instance
        for v, local_lr in optimizer.parameter_learning_rate_multipliers.items()
        if local_lr > 0.0})

    optimizer.weight_decay = o.solver.weight_decay
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = o.solver.lr_decay_interval if o.solver.lr_decay_interval > 0 else 1

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables, optimizer.parameter_learning_rate_multipliers)

    return optimizer
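
Unlike Example #1, which resolves parameter names through _get_matching_variable_names, this older variant selects parameter variables by a plain prefix match. A small self-contained illustration of that matching rule, with made-up variable names:

variables = ['conv1/conv/W', 'conv1/conv/b', 'affine/affine/W']
prefix = 'conv1'
matched = [name for name in variables if name.find(prefix) == 0]
print(matched)  # ['conv1/conv/W', 'conv1/conv/b']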