# Source code for mushroom_rl_benchmark.builders.actor_critic.deep_actor_critic.trpo

import torch.optim as optim
import torch.nn.functional as F

from mushroom_rl.algorithms.actor_critic import TRPO
from mushroom_rl.policy import GaussianTorchPolicy

from mushroom_rl_benchmark.builders import AgentBuilder
from mushroom_rl_benchmark.builders.network import TRPONetwork as Network


class TRPOBuilder(AgentBuilder):
    """
    AgentBuilder for Trust Region Policy optimization algorithm (TRPO)

    """
    def __init__(self, policy_params, critic_params, alg_params,
                 n_steps_per_fit=3000, preprocessors=None):
        """
        Constructor.

        Args:
            policy_params (dict): parameters for the policy;
            critic_params (dict): parameters for the critic;
            alg_params (dict): parameters for the algorithm;
            n_steps_per_fit (int, 3000): number of steps per fit;
            preprocessors (list, None): list of preprocessors.

        """
        self.policy_params = policy_params
        self.critic_params = critic_params
        self.alg_params = alg_params
        super().__init__(n_steps_per_fit=n_steps_per_fit,
                         preprocessors=preprocessors)

    def build(self, mdp_info):
        """
        Build a TRPO agent for the given MDP.

        The policy is a ``GaussianTorchPolicy`` built on ``Network``, sized
        from the observation and action space shapes. The critic parameters
        are completed with ``input_shape`` (the observation space shape) and
        passed to ``TRPO`` as ``critic_params``.

        The stored ``critic_params`` / ``alg_params`` dicts are NOT modified:
        shallow copies are extended with the MDP-dependent entries, so the
        builder can be reused for several builds without leaking state
        between them or into the dicts given to the constructor.

        Args:
            mdp_info: information about the MDP (presumably a mushroom_rl
                ``MDPInfo``); its ``observation_space.shape`` and
                ``action_space.shape`` are read.

        Returns:
            The ``TRPO`` agent.

        """
        policy = GaussianTorchPolicy(
            Network,
            mdp_info.observation_space.shape,
            mdp_info.action_space.shape,
            **self.policy_params)

        # Copy rather than mutate self.critic_params / self.alg_params in
        # place: in-place updates would persist across builds, be shared by
        # every agent built, and alter the caller's original dicts.
        critic_params = dict(self.critic_params,
                             input_shape=mdp_info.observation_space.shape)
        alg_params = dict(self.alg_params, critic_params=critic_params)

        return TRPO(mdp_info, policy, **alg_params)

    def compute_Q(self, agent, states):
        """
        Compute the mean critic estimate over ``states``.

        TRPO's critic (``agent._V``) is a state-value function, so its mean
        prediction is what is reported here.

        Args:
            agent: the agent returned by ``build``;
            states: the states on which the critic is evaluated.

        Returns:
            The mean of ``agent._V(states)``.

        """
        return agent._V(states).mean()

    @classmethod
    def default(cls, critic_lr=3e-4, critic_network=Network, max_kl=1e-2,
                ent_coeff=0.0, lam=.95, batch_size=64, n_features=32,
                critic_fit_params=None, n_steps_per_fit=3000,
                n_epochs_line_search=10, n_epochs_cg=100, cg_damping=1e-2,
                cg_residual_tol=1e-10, std_0=1.0, preprocessors=None,
                use_cuda=False, get_default_dict=False):
        """
        Create a TRPOBuilder with default parameters.

        Args:
            critic_lr (float, 3e-4): learning rate of the Adam optimizer of
                the critic;
            critic_network (class, Network): network class of the critic;
            max_kl (float, 1e-2): forwarded to ``TRPO``;
            ent_coeff (float, 0.0): forwarded to ``TRPO``;
            lam (float, .95): forwarded to ``TRPO``;
            batch_size (int, 64): batch size of the critic;
            n_features (int, 32): number of features of both the policy and
                the critic networks;
            critic_fit_params (dict, None): forwarded to ``TRPO``;
            n_steps_per_fit (int, 3000): number of steps per fit;
            n_epochs_line_search (int, 10): forwarded to ``TRPO``;
            n_epochs_cg (int, 100): forwarded to ``TRPO``;
            cg_damping (float, 1e-2): forwarded to ``TRPO``;
            cg_residual_tol (float, 1e-10): forwarded to ``TRPO``;
            std_0 (float, 1.0): forwarded to the Gaussian policy;
            preprocessors (list, None): list of preprocessors;
            use_cuda (bool, False): forwarded to the policy only;
            get_default_dict (bool, False): if True, also return the
                arguments this method was called with.

        Returns:
            The builder or, if ``get_default_dict`` is True, the tuple
            ``(builder, defaults)``, where ``defaults`` maps every argument
            name (including ``cls``) to its value.

        """
        # Must stay the first statement so only the arguments are captured.
        # The explicit copy guarantees later locals (policy_params, builder,
        # ...) never appear in it, even where locals() returns the live frame
        # dict that a tracer/debugger may refresh (CPython < 3.13).
        defaults = dict(locals())

        # NOTE(review): use_cuda is only given to the policy, not to the
        # critic; confirm this is intended.
        policy_params = dict(
            std_0=std_0,
            n_features=n_features,
            use_cuda=use_cuda)

        critic_params = dict(
            network=critic_network,
            optimizer={
                'class': optim.Adam,
                'params': {'lr': critic_lr}},
            loss=F.mse_loss,
            n_features=n_features,
            batch_size=batch_size,
            output_shape=(1,))

        alg_params = dict(
            ent_coeff=ent_coeff,
            max_kl=max_kl,
            lam=lam,
            n_epochs_line_search=n_epochs_line_search,
            n_epochs_cg=n_epochs_cg,
            cg_damping=cg_damping,
            cg_residual_tol=cg_residual_tol,
            critic_fit_params=critic_fit_params)

        builder = cls(policy_params, critic_params, alg_params,
                      n_steps_per_fit=n_steps_per_fit,
                      preprocessors=preprocessors)

        if get_default_dict:
            return builder, defaults
        else:
            return builder