Beispiel #1
0
def _create_optimizer(ctx, o, networks, datasets):
    """Build an optimizer object from the optimizer definition ``o``.

    Args:
        ctx: Context entered with ``nn.context_scope`` while the solver is
            created.
        o: Optimizer definition. Reads ``name``, ``order``,
            ``update_interval``, ``network_name``, ``dataset_name``,
            ``data_variable``, ``generator_variable``, ``loss_variable``,
            ``parameter_variable`` and ``solver``.
        networks: Mapping from network name to a network object exposing
            ``variables``, ``get_forward_sequence`` and
            ``get_backward_sequence``.
        datasets: Mapping from dataset name to an object exposing
            ``data_iterator``.

    Returns:
        A plain attribute container holding the network, data iterator,
        variable assignments, solver, weight/learning-rate decay settings and
        the forward/backward sequences computed for the loss variables.

    Raises:
        ValueError: If ``o.solver.type`` names an unsupported solver.
    """
    class Optimizer:
        pass

    optimizer = Optimizer()

    optimizer.name = o.name
    optimizer.order = o.order
    # Non-positive intervals fall back to 1.
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterator = datasets[o.dataset_name].data_iterator

    variables = optimizer.network.variables

    # Network variable -> ``data_name`` assigned to it.
    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[variables[d.variable_name]] = d.data_name

    # Network variable -> generator built from its definition.
    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[
            variables[g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = [
        variables[loss.variable_name] for loss in o.loss_variable]

    # Each parameter entry applies its learning-rate multiplier to every
    # network variable whose name starts with ``p.variable_name``; when
    # several entries match the same variable, the later one wins.
    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        matching_names = [name for name in variables.keys()
                          if name.startswith(p.variable_name)]
        for v_name in matching_names:
            optimizer.parameter_learning_rate_multipliers[
                variables[v_name]] = p.learning_rate_multiplier

    solver_def = o.solver
    with nn.context_scope(ctx):
        if solver_def.type == 'Adagrad':
            param = solver_def.adagrad_param
            optimizer.solver = S.Adagrad(param.lr, param.eps)
        elif solver_def.type == 'Adadelta':
            param = solver_def.adadelta_param
            optimizer.solver = S.Adadelta(param.lr, param.decay, param.eps)
        elif solver_def.type == 'Adam':
            param = solver_def.adam_param
            optimizer.solver = S.Adam(param.alpha, param.beta1,
                                      param.beta2, param.eps)
        elif solver_def.type == 'Adamax':
            param = solver_def.adamax_param
            optimizer.solver = S.Adamax(param.alpha, param.beta1,
                                        param.beta2, param.eps)
        elif solver_def.type == 'Eve':
            param = solver_def.eve_param
            optimizer.solver = S.Eve(param.alpha, param.beta1, param.beta2,
                                     param.beta3, param.k, param.k2,
                                     param.eps)
        elif solver_def.type == 'Momentum':
            param = solver_def.momentum_param
            optimizer.solver = S.Momentum(param.lr, param.momentum)
        elif solver_def.type == 'Nesterov':
            param = solver_def.nesterov_param
            optimizer.solver = S.Nesterov(param.lr, param.momentum)
        elif solver_def.type == 'RMSprop':
            param = solver_def.rmsprop_param
            optimizer.solver = S.RMSprop(param.lr, param.decay, param.eps)
        elif solver_def.type in ('Sgd', 'SGD'):
            optimizer.solver = S.Sgd(solver_def.sgd_param.lr)
        else:
            raise ValueError('Solver "' + solver_def.type +
                             '" is not supported.')

    # Only parameters with a positive learning-rate multiplier are handed to
    # the solver.
    optimizer.solver.set_parameters({
        v.name: v.variable_instance
        for v, local_lr in optimizer.parameter_learning_rate_multipliers.items()
        if local_lr > 0.0})

    optimizer.weight_decay = solver_def.weight_decay
    # Non-positive values fall back to a factor of 1.0 / an interval of 1.
    optimizer.lr_decay = solver_def.lr_decay if solver_def.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = (solver_def.lr_decay_interval
                                   if solver_def.lr_decay_interval > 0 else 1)

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables, optimizer.parameter_learning_rate_multipliers)

    return optimizer
Beispiel #2
0
# Supported solvers keyed by the solver type name of the optimizer
# definition. Each entry is
#   (parameter sub-message attribute, solver class name in ``S``,
#    constructor argument fields in positional order).
# The first argument field of every solver is its initial learning rate.
_OPTIMIZER_SOLVER_SPECS = {
    'Adagrad': ('adagrad_param', 'Adagrad', ('lr', 'eps')),
    'Adadelta': ('adadelta_param', 'Adadelta', ('lr', 'decay', 'eps')),
    'Adam': ('adam_param', 'Adam', ('alpha', 'beta1', 'beta2', 'eps')),
    'Adamax': ('adamax_param', 'Adamax', ('alpha', 'beta1', 'beta2', 'eps')),
    'AdaBound': ('adabound_param', 'AdaBound',
                 ('alpha', 'beta1', 'beta2', 'eps', 'final_lr', 'gamma')),
    'AMSGRAD': ('amsgrad_param', 'AMSGRAD',
                ('alpha', 'beta1', 'beta2', 'eps')),
    'AMSBound': ('amsbound_param', 'AMSBound',
                 ('alpha', 'beta1', 'beta2', 'eps', 'final_lr', 'gamma')),
    'Eve': ('eve_param', 'Eve',
            ('alpha', 'beta1', 'beta2', 'beta3', 'k', 'k2', 'eps')),
    'Momentum': ('momentum_param', 'Momentum', ('lr', 'momentum')),
    'Nesterov': ('nesterov_param', 'Nesterov', ('lr', 'momentum')),
    'RMSprop': ('rmsprop_param', 'RMSprop', ('lr', 'decay', 'eps')),
    'Sgd': ('sgd_param', 'Sgd', ('lr',)),
    'SGD': ('sgd_param', 'Sgd', ('lr',)),
}


def _build_optimizer_solver(solver_def):
    """Instantiate the solver described by ``solver_def``.

    Returns:
        ``(solver, init_lr)`` where ``init_lr`` is the solver's first
        constructor argument (``lr`` or ``alpha``).

    Raises:
        ValueError: If ``solver_def.type`` is not a supported solver.
    """
    spec = _OPTIMIZER_SOLVER_SPECS.get(solver_def.type)
    if spec is None:
        raise ValueError('Solver "' + solver_def.type +
                         '" is not supported.')
    param_attr, class_name, arg_fields = spec
    param = getattr(solver_def, param_attr)
    args = [getattr(param, field) for field in arg_fields]
    solver = getattr(S, class_name)(*args)
    return solver, args[0]


def _per_worker_interval(interval, comm_size):
    """Divide an iteration interval across ``comm_size`` workers (minimum 1)."""
    return interval // comm_size if interval > comm_size else 1


def _build_optimizer_scheduler(solver_def, init_lr, comm_size):
    """Build the learning-rate scheduler described by ``solver_def``.

    Iteration counts are divided by ``comm_size``. Without a configured
    scheduler (or with a no-op configuration) a constant-rate
    ``ExponentialScheduler(init_lr, 1.0, 1)`` is returned. A ``'Linear'``
    warmup wraps the result when ``warmup_iter >= comm_size``.

    Raises:
        NotImplementedError: For the ``'Custom'`` scheduler type.
        ValueError: For any other unknown scheduler type.
    """
    scheduler = ExponentialScheduler(init_lr, 1.0, 1)
    scheduler_type = solver_def.lr_scheduler_type

    if scheduler_type == 'Polynomial':
        param = solver_def.polynomial_scheduler_param
        if param.power != 0.0:
            scheduler = PolynomialScheduler(
                init_lr, param.max_iter // comm_size, param.power)
    elif scheduler_type == 'Cosine':
        scheduler = CosineScheduler(
            init_lr, solver_def.cosine_scheduler_param.max_iter // comm_size)
    elif scheduler_type == 'Exponential':
        param = solver_def.exponential_scheduler_param
        if param.gamma != 1.0:
            scheduler = ExponentialScheduler(
                init_lr, param.gamma,
                _per_worker_interval(param.iter_interval, comm_size))
    elif scheduler_type == 'Step':
        param = solver_def.step_scheduler_param
        if param.gamma != 1.0 and len(param.iter_steps) > 0:
            scheduler = StepScheduler(
                init_lr, param.gamma,
                [step // comm_size for step in param.iter_steps])
    elif scheduler_type == 'Custom':
        # TODO: custom schedulers are not implemented yet.
        raise NotImplementedError()
    elif scheduler_type == '':
        # No explicit scheduler type: fall back to the lr_decay fields.
        if solver_def.lr_decay_interval != 0 or solver_def.lr_decay != 0.0:
            scheduler = ExponentialScheduler(
                init_lr,
                solver_def.lr_decay if solver_def.lr_decay > 0.0 else 1.0,
                _per_worker_interval(solver_def.lr_decay_interval, comm_size))
    else:
        raise ValueError('Learning Rate Scheduler "' +
                         solver_def.lr_scheduler_type + '" is not supported.')

    if solver_def.lr_warmup_scheduler_type == 'Linear':
        warmup_iter = solver_def.linear_warmup_scheduler_param.warmup_iter
        if warmup_iter >= comm_size:
            scheduler = LinearWarmupScheduler(scheduler,
                                              warmup_iter // comm_size)

    return scheduler


def _create_optimizer(ctx, o, networks, datasets):
    """Build an optimizer object from the optimizer definition ``o``.

    Args:
        ctx: Context entered with ``nn.context_scope`` while the solver is
            created.
        o: Optimizer definition (iteration bounds, network/dataset names,
            variable assignments, solver and scheduler settings).
        networks: Mapping from network name to a network object exposing
            ``variables``, ``get_forward_sequence`` and
            ``get_backward_sequence``.
        datasets: Mapping from dataset name to an object exposing
            ``data_iterator``.

    Returns:
        A plain attribute container holding the communicator, per-worker
        iteration bounds, network, data iterators, variable assignments,
        solver, sorted parameter dict, learning-rate scheduler and the
        forward/backward sequences for the loss variables.

    Raises:
        ValueError: For an unsupported solver or learning-rate scheduler.
        NotImplementedError: For the ``'Custom'`` learning-rate scheduler.
    """
    class Optimizer:
        pass

    optimizer = Optimizer()

    # Looked up once; reused for every per-worker iteration scaling below.
    optimizer.comm = current_communicator()
    comm_size = optimizer.comm.size if optimizer.comm else 1
    # Ceil-divide the configured iteration bounds across workers;
    # non-positive values map to 0.
    optimizer.start_iter = (
        ((o.start_iter - 1) // comm_size + 1) if o.start_iter > 0 else 0)
    optimizer.end_iter = (
        ((o.end_iter - 1) // comm_size + 1) if o.end_iter > 0 else 0)
    optimizer.name = o.name
    optimizer.order = o.order
    optimizer.update_interval = o.update_interval if o.update_interval > 0 else 1
    optimizer.network = networks[o.network_name]
    optimizer.data_iterators = OrderedDict(
        (name, datasets[name].data_iterator) for name in o.dataset_name)

    variables = optimizer.network.variables

    # Network variable -> ``data_name`` assigned to it.
    optimizer.dataset_assign = OrderedDict()
    for d in o.data_variable:
        optimizer.dataset_assign[variables[d.variable_name]] = d.data_name

    # Network variable -> generator built from its definition.
    optimizer.generator_assign = OrderedDict()
    for g in o.generator_variable:
        optimizer.generator_assign[
            variables[g.variable_name]] = _get_generator(g)

    optimizer.loss_variables = [
        variables[loss.variable_name] for loss in o.loss_variable]

    # Later parameter entries overwrite earlier ones for the same variable.
    optimizer.parameter_learning_rate_multipliers = OrderedDict()
    for p in o.parameter_variable:
        for v_name in _get_matching_variable_names(p.variable_name,
                                                   variables.keys()):
            optimizer.parameter_learning_rate_multipliers[
                variables[v_name]] = p.learning_rate_multiplier

    with nn.context_scope(ctx):
        optimizer.solver, init_lr = _build_optimizer_solver(o.solver)

    # Only parameters with a positive learning-rate multiplier are optimized.
    parameters = {
        v.name: v.variable_instance
        for v, local_lr in optimizer.parameter_learning_rate_multipliers.items()
        if local_lr > 0.0
    }
    optimizer.solver.set_parameters(parameters)
    optimizer.parameters = OrderedDict(
        sorted(parameters.items(), key=lambda item: item[0]))

    optimizer.weight_decay = o.solver.weight_decay

    # keep following 2 lines for backward compatibility
    optimizer.lr_decay = o.solver.lr_decay if o.solver.lr_decay > 0.0 else 1.0
    optimizer.lr_decay_interval = (o.solver.lr_decay_interval
                                   if o.solver.lr_decay_interval > 0 else 1)
    optimizer.solver.set_states_from_protobuf(o)

    optimizer.scheduler = _build_optimizer_scheduler(o.solver, init_lr,
                                                     comm_size)

    optimizer.forward_sequence = optimizer.network.get_forward_sequence(
        optimizer.loss_variables)
    optimizer.backward_sequence = optimizer.network.get_backward_sequence(
        optimizer.loss_variables,
        optimizer.parameter_learning_rate_multipliers)

    return optimizer
Beispiel #3
0
def train():
    """Train a DQN agent on Pong with experience replay and a target network.

    Hyper-parameters come from ``Config``. The online network is built under
    the "dqn" parameter scope and trained with RMSprop. The "target" scope
    holds a copy of its parameters that is refreshed every
    ``Config.TARGET_NET_UPDATE_INTERVAL`` steps. Training curves are written
    through monitors rooted at ``Config.MONITOR_PATH``. Parameters are saved
    to ``model_<step>.h5`` every ``Config.MODEL_SAVE_INTERVAL`` steps.
    """
    from collections import deque

    if Config.USE_NW:
        env = Environment('Pong-v0')
    else:
        env = gym.make('Pong-v0')

    extension_module = Config.CONTEXT
    logger.info("Running in {}".format(extension_module))
    ctx = extension_context(extension_module, device_id=Config.DEVICE_ID)
    nn.set_default_context(ctx)

    monitor = Monitor(Config.MONITOR_PATH)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=1)
    monitor_reward = MonitorSeries("Training reward", monitor, interval=1)
    monitor_q = MonitorSeries("Training q", monitor, interval=1)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=1)

    # placeholders: batches of stacked frames for the online / target nets
    image = nn.Variable([
        Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH,
        Config.FRAME_HEIGHT
    ])
    image_target = nn.Variable([
        Config.BATCH_SIZE, Config.STATE_LENGTH, Config.FRAME_WIDTH,
        Config.FRAME_HEIGHT
    ])

    nn.clear_parameters()

    # create networks; `persistent` keeps their outputs from being cleared
    # at backward (the attribute was previously misspelled "prersistent")
    with nn.parameter_scope("dqn"):
        q = dqn(image, test=False)
        q.persistent = True
    with nn.parameter_scope("target"):
        target_q = dqn(image_target, test=False)
        target_q.persistent = True

    # loss: squared error between the TD target `t` and the Q value of the
    # taken action `a` (6 actions)
    a = nn.Variable([Config.BATCH_SIZE, 1])
    q_val = F.sum(F.one_hot(a, (6, )) * q, axis=1, keepdims=True)
    t = nn.Variable([Config.BATCH_SIZE, 1])
    loss = F.mean(F.squared_error(t, q_val))

    if Config.RESUME:
        logger.info('load model: {}'.format(Config.RESUME))
        nn.load_parameters(Config.RESUME)

    # setup solver: update the online ("dqn") parameters only
    solver = S.RMSprop(lr=Config.LEARNING_RATE,
                       decay=Config.DECAY,
                       eps=Config.EPSILON)
    with nn.parameter_scope("dqn"):
        solver.set_parameters(nn.get_parameters())

    # training
    epsilon = Config.INIT_EPSILON
    # FIFO replay memory: appending to a full deque drops the oldest entry
    # in O(1) (list.pop(0) was O(n) on every step).
    experiences = deque(maxlen=int(Config.REPLAY_MEMORY_SIZE))
    step = 0
    for episode in range(Config.EPISODE_LENGTH):
        logger.info("EPISODE {}".format(episode))
        done = False
        env.reset()
        # advance 30 frames with action 0 before acting
        for _ in range(30):
            observation_next, reward, done, info = env.step(0)
            observation_next = preprocess_frame(observation_next)
        # initial state: the last frame repeated STATE_LENGTH times
        state = np.stack(
            [observation_next for _ in range(Config.STATE_LENGTH)], axis=0)
        total_reward = 0
        while not done:
            # select action (epsilon-greedy)
            if step % Config.ACTION_INTERVAL == 0:
                if random.random() > epsilon or len(
                        experiences) >= Config.REPLAY_MEMORY_SIZE:
                    # inference; assumes `image.d = state` broadcasts the
                    # single state over the whole batch, so every row of
                    # q.d is the same -- read row 0 to stay in action range
                    image.d = state
                    q.forward()
                    action = np.argmax(q.d[0])
                else:
                    # random action
                    if Config.USE_NW:
                        action = env.sample()
                    else:
                        action = env.action_space.sample()  # TODO refactor
                if epsilon > Config.MIN_EPSILON:
                    epsilon -= Config.EPSILON_REDUCTION_PER_STEP

            # get next environment
            observation_next, reward, done, info = env.step(action)
            observation_next = preprocess_frame(observation_next)
            total_reward += reward
            # TODO clip reward

            # update replay memory (FIFO)
            state_next = np.append(state[1:, :, :],
                                   observation_next[np.newaxis, :, :],
                                   axis=0)
            experiences.append((state_next, reward, action, state, done))

            # update network
            if step % Config.NET_UPDATE_INTERVAL == 0 and len(
                    experiences) > Config.INIT_REPLAY_SIZE:
                logger.info("update {}".format(step))
                batch = random.sample(experiences, Config.BATCH_SIZE)
                (next_states, rewards, actions, states,
                 dones) = zip(*batch)
                batch_reward = np.array(rewards)[:, np.newaxis]
                batch_action = np.array(actions)[:, np.newaxis]
                batch_done = np.array(dones, dtype=np.float32)[:, np.newaxis]

                image.d = np.array(states).astype(np.float32)
                image_target.d = np.array(next_states).astype(np.float32)
                a.d = batch_action
                q_val.forward()  # XXX
                target_q.forward()
                # TD target: r + gamma * max_a' Q_target(s', a'), zeroed
                # bootstrap on terminal transitions
                t.d = batch_reward + (1 - batch_done) * Config.GAMMA * np.max(
                    target_q.d, axis=1, keepdims=True)
                solver.zero_grad()
                loss.forward()
                loss.backward()

                monitor_loss.add(step, loss.d.copy())
                monitor_reward.add(step, total_reward)
                monitor_q.add(step, np.mean(q.d.copy()))
                monitor_time.add(step)
                # TODO weight clip
                solver.update()
                logger.info("update done {}".format(step))

            # update target network
            if step % Config.TARGET_NET_UPDATE_INTERVAL == 0:
                # copy parameter values from dqn to target, paired by order
                with nn.parameter_scope("dqn"):
                    src = nn.get_parameters()
                with nn.parameter_scope("target"):
                    dst = nn.get_parameters()
                for s_val, d_val in zip(src.values(), dst.values()):
                    # Variable#d is a reference; copy so the two networks do
                    # not share a buffer
                    d_val.d = s_val.d.copy()

            if step % Config.MODEL_SAVE_INTERVAL == 0:
                logger.info("save model")
                nn.save_parameters("model_{}.h5".format(step))

            step += 1
            state = state_next