Source code for mushroom_rl_benchmark.builders.agent_builder

from copy import deepcopy
import mushroom_rl.utils.preprocessors as m_prep


class AgentBuilder:
    """
    Base class to spawn instances of a MushroomRL agent.

    This class only stores the fit parameters, the preprocessors and three
    ``compute_*`` flags. ``build``, ``compute_Q`` and ``default`` raise
    ``NotImplementedError`` here and have to be provided by subclasses.

    """

    def __init__(self, n_steps_per_fit=None, n_episodes_per_fit=None,
                 compute_policy_entropy=True,
                 compute_entropy_with_states=False,
                 compute_value_function=True, preprocessors=None):
        """
        Initialize AgentBuilder.

        Args:
            n_steps_per_fit (None): value returned under the
                ``n_steps_per_fit`` key by ``get_fit_params``;
            n_episodes_per_fit (None): value returned under the
                ``n_episodes_per_fit`` key by ``get_fit_params``;
            compute_policy_entropy (bool, True): flag stored as a public
                attribute, not used by this base class;
            compute_entropy_with_states (bool, False): flag stored as a
                public attribute, not used by this base class;
            compute_value_function (bool, True): flag stored as a public
                attribute, not used by this base class;
            preprocessors (None): forwarded to ``set_preprocessors``.

        Raises:
            AssertionError: if ``n_steps_per_fit`` and
                ``n_episodes_per_fit`` are both ``None`` or both set.

        """
        # Exactly one of the two fit parameters must be provided. The check
        # is raised explicitly (instead of using a bare ``assert``) so that
        # it is not stripped when Python runs with ``-O``; the exception
        # type is kept for callers that already catch AssertionError.
        if (n_steps_per_fit is None) == (n_episodes_per_fit is None):
            raise AssertionError(
                'Exactly one of n_steps_per_fit and n_episodes_per_fit '
                'must be specified'
            )

        self._n_steps_per_fit = n_steps_per_fit
        self._n_episodes_per_fit = n_episodes_per_fit

        # Always leaves self._preprocessors as a list (possibly empty).
        self.set_preprocessors(preprocessors)

        self.compute_policy_entropy = compute_policy_entropy
        self.compute_entropy_with_states = compute_entropy_with_states
        self.compute_value_function = compute_value_function

    def get_fit_params(self):
        """
        Get n_steps_per_fit and n_episodes_per_fit for the specific
        AgentBuilder.

        Returns:
            A dictionary with the keys ``n_steps_per_fit`` and
            ``n_episodes_per_fit`` holding the values given at construction.

        """
        return dict(n_steps_per_fit=self._n_steps_per_fit,
                    n_episodes_per_fit=self._n_episodes_per_fit)

    def set_preprocessors(self, preprocessors):
        """
        Set preprocessor for the specific AgentBuilder.

        Args:
            preprocessors: a single preprocessor or a list/tuple of
                preprocessors. Entries given as ``str`` are looked up by
                name in ``mushroom_rl.utils.preprocessors``; every other
                entry is stored unchanged. A falsy value (e.g. ``None`` or
                an empty list) clears the preprocessors.

        """
        if not preprocessors:
            self._preprocessors = list()
            return

        # Tuples are treated as a collection just like lists; anything else
        # is considered a single preprocessor and wrapped.
        if not isinstance(preprocessors, (list, tuple)):
            preprocessors = [preprocessors]

        self._preprocessors = [
            getattr(m_prep, p) if isinstance(p, str) else p
            for p in preprocessors
        ]

    def get_preprocessors(self):
        """
        Get preprocessors for the specific AgentBuilder.

        Returns:
            The internal list of preprocessors (not a copy).

        """
        return self._preprocessors

    def copy(self):
        """
        Create a deepcopy of the AgentBuilder and return it.

        """
        return deepcopy(self)

    def build(self, mdp_info):
        """
        Build and return the AgentBuilder.

        Args:
            mdp_info (MDPInfo): information about the environment.

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError('AgentBuilder is an abstract class')

    def compute_Q(self, agent, states):
        """
        Compute the Q Value for an AgentBuilder.

        Args:
            agent (Agent): the considered agent;
            states (np.ndarray): the set of states over which we need
                to compute the Q function.

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError('AgentBuilder is an abstract class')

    def set_eval_mode(self, agent, eval):
        """
        Set the eval mode for the agent. This function can be overwritten by
        any agent builder to setup specific evaluation mode for the agent.
        The base implementation does nothing.

        Args:
            agent (Agent): the considered agent;
            eval (bool): whether to set eval mode (true) or learn mode.

        """
        pass

    @classmethod
    def default(cls, get_default_dict=False, **kwargs):
        """
        Create a default initialization for the specific AgentBuilder and
        return it.

        Raises:
            NotImplementedError: always, in this base class.

        """
        raise NotImplementedError('AgentBuilder is an abstract class')