Source code for mushroom_rl_benchmark.builders.value.td.true_online_sarsa_lambda

from mushroom_rl.algorithms.value import TrueOnlineSARSALambda
from mushroom_rl.policy import EpsGreedy

from mushroom_rl.utils.parameters import ExponentialParameter, Parameter

from .td_continuous import TDContinuousBuilder


class TrueOnlineSarsaLambdaBuilder(TDContinuousBuilder):
    """
    AgentBuilder for True Online Sarsa(Lambda) Continuous.
    Using tiles as function approximator.

    """

    def __init__(self, policy, learning_rate, lambda_coeff, epsilon,
                 epsilon_test, n_tilings, n_tiles):
        """
        Constructor.

        Every argument is forwarded unchanged, and in the same order, to
        the ``TDContinuousBuilder`` constructor.

        Args:
            policy: policy handed to the agent (``default`` passes an
                ``EpsGreedy`` instance);
            learning_rate: learning rate parameter (``default`` passes
                ``Parameter(alpha / n_tilings)``);
            lambda_coeff: eligibility trace coefficient parameter;
            epsilon: exploration parameter (presumably the one used while
                learning -- handled by the parent class, confirm there);
            epsilon_test: exploration parameter (presumably the one used
                while evaluating -- handled by the parent class, confirm
                there);
            n_tilings: number of tilings, forwarded to the parent builder;
            n_tiles: number of tiles, forwarded to the parent builder.

        """
        super().__init__(
            policy,
            learning_rate,
            lambda_coeff,
            epsilon,
            epsilon_test,
            n_tilings,
            n_tiles,
        )

    def build(self, mdp_info):
        """
        Build the ``TrueOnlineSARSALambda`` agent for the given environment.

        Args:
            mdp_info: environment information, passed both to the function
                approximation helper and to the agent constructor.

        Returns:
            A ``TrueOnlineSARSALambda`` instance.

        """
        # ``_build_function_approximation`` is not defined in this class;
        # it is expected to come from ``TDContinuousBuilder``, as are the
        # ``policy`` / ``learning_rate`` / ``lambda_coeff`` attributes.
        features, approximator_params = \
            self._build_function_approximation(mdp_info)

        agent = TrueOnlineSARSALambda(
            mdp_info,
            self.policy,
            self.learning_rate,
            self.lambda_coeff,
            features,
            approximator_params,
        )
        return agent

    @classmethod
    def default(cls, alpha=.1, lambda_coeff=0.9, epsilon=0., decay_eps=0.,
                epsilon_test=0., n_tilings=10, n_tiles=10,
                get_default_dict=False):
        """
        Create a builder from plain numeric hyperparameters.

        Args:
            alpha: base learning rate; the builder receives
                ``Parameter(alpha / n_tilings)``;
            lambda_coeff: value wrapped in a ``Parameter`` and used as the
                lambda coefficient;
            epsilon: initial value of the policy's epsilon parameter;
            decay_eps: when equal to 0 epsilon is a plain ``Parameter``,
                otherwise an ``ExponentialParameter`` with
                ``exp=decay_eps``;
            epsilon_test: value wrapped in a ``Parameter`` and passed as
                the builder's ``epsilon_test``;
            n_tilings: number of tilings, also the divisor of ``alpha``;
            n_tiles: number of tiles;
            get_default_dict: whether to also return the dictionary
                captured with ``locals()``.

        Returns:
            The builder, or the tuple ``(builder, defaults)`` when
            ``get_default_dict`` is truthy.

        """
        epsilon_p = (
            Parameter(value=epsilon)
            if decay_eps == 0
            else ExponentialParameter(value=epsilon, exp=decay_eps)
        )

        # The snapshot is taken at this exact point: it contains ``cls``,
        # every call argument and ``epsilon_p``. Do not bind any other
        # local before this line, or the returned dictionary changes.
        defaults = locals()

        epsilon_test_p = Parameter(value=epsilon_test)
        policy = EpsGreedy(epsilon=epsilon_p)
        lambda_coeff_p = Parameter(lambda_coeff)
        learning_rate = Parameter(alpha / n_tilings)

        builder = cls(
            policy,
            learning_rate,
            lambda_coeff_p,
            epsilon_p,
            epsilon_test_p,
            n_tilings=n_tilings,
            n_tiles=n_tiles,
        )

        return (builder, defaults) if get_default_dict else builder