Source code for mushroom_rl_benchmark.builders.value.dqn.dqn

import torch.nn.functional as F
import torch.optim as optim

from mushroom_rl.algorithms.value import DQN
from mushroom_rl.approximators.parametric import TorchApproximator
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import LinearParameter, Parameter

from mushroom_rl_benchmark.builders import AgentBuilder
from mushroom_rl_benchmark.builders.network import DQNNetwork


class DQNBuilder(AgentBuilder):
    """
    AgentBuilder for Deep Q-Network (DQN).

    """

    def __init__(self, policy, approximator, approximator_params, alg_params,
                 n_steps_per_fit=1):
        """
        Constructor.

        Args:
            policy (Policy): exploration policy instance;
            approximator (class): Q-function approximator class,
                e.g. TorchApproximator;
            approximator_params (dict): parameters of the Q-function
                approximator;
            alg_params (dict): parameters for the algorithm;
            n_steps_per_fit (int, 1): number of steps per fit.

        """
        self.policy = policy
        self.approximator = approximator
        self.approximator_params = approximator_params
        self.alg_params = alg_params

        super().__init__(n_steps_per_fit=n_steps_per_fit,
                         compute_policy_entropy=False)

    def build(self, mdp_info):
        # Infer the network shapes from the environment specification.
        self.approximator_params['input_shape'] = mdp_info.observation_space.shape
        self.approximator_params['output_shape'] = (mdp_info.action_space.n,)
        self.approximator_params['n_actions'] = mdp_info.action_space.n

        # Exploration schedule: epsilon is annealed linearly from 1 to .05
        # over the first million steps, while a small fixed epsilon is used
        # for evaluation (see set_eval_mode below).
        self.epsilon = LinearParameter(value=1., threshold_value=.05, n=1000000)
        self.epsilon_test = Parameter(value=.01)

        return DQN(mdp_info, self.policy, self.approximator,
                   self.approximator_params, **self.alg_params)

    def compute_Q(self, agent, states):
        q_max = agent.approximator(states).max()

        return q_max
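
    # compute_Q reduces the approximator's predictions over a batch of
    # states to a single scalar, the largest action-value in the batch;
    # this is the value estimate the benchmark logs for DQN agents.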

    def set_eval_mode(self, agent, eval):
        if eval:
            agent.policy.set_epsilon(self.epsilon_test)
        else:
            agent.policy.set_epsilon(self.epsilon)

    @classmethod
    def default(cls, lr=.0001, network=DQNNetwork, initial_replay_size=50000,
                max_replay_size=1000000, batch_size=32,
                target_update_frequency=2500, n_steps_per_fit=1,
                use_cuda=False, get_default_dict=False):
        defaults = locals()

        policy = EpsGreedy(epsilon=Parameter(value=1.))

        approximator_params = dict(
            network=network,
            optimizer={'class': optim.Adam,
                       'params': {'lr': lr}},
            loss=F.smooth_l1_loss,
            use_cuda=use_cuda)

        alg_params = dict(
            initial_replay_size=initial_replay_size,
            max_replay_size=max_replay_size,
            batch_size=batch_size,
            target_update_frequency=target_update_frequency
        )

        builder = cls(policy, TorchApproximator, approximator_params,
                      alg_params, n_steps_per_fit)

        if get_default_dict:
            return builder, defaults
        else:
            return builder
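

# A minimal usage sketch, not part of the original module: it assumes
# MushroomRL's Gym wrapper and the 'CartPole-v1' environment, both chosen
# here purely for illustration; any discrete-action environment would do.
# The replay warm-up is shrunk from the 50000-step default so the sketch
# runs quickly.
if __name__ == '__main__':
    import numpy as np

    from mushroom_rl.core import Core
    from mushroom_rl.environments import Gym

    mdp = Gym('CartPole-v1', horizon=500, gamma=.99)

    # Build an agent with the default hyper-parameters defined above.
    builder = DQNBuilder.default(lr=.0001, initial_replay_size=500,
                                 batch_size=32)
    agent = builder.build(mdp.info)

    # Learn for a few thousand steps; DQN starts fitting once the replay
    # buffer holds initial_replay_size transitions.
    core = Core(agent, mdp)
    core.learn(n_steps=5000, n_steps_per_fit=1)

    # Evaluate with the small test epsilon installed by set_eval_mode, and
    # log the value metric the same way the benchmark does.
    builder.set_eval_mode(agent, eval=True)
    dataset = core.evaluate(n_episodes=5)
    states = np.array([step[0] for step in dataset])
    print('max Q:', builder.compute_Q(agent, states))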