Source code for mushroom_rl_benchmark.builders.network.sac_network

import torch
import torch.nn as nn
import torch.nn.functional as F


class SACCriticNetwork(nn.Module):
    """
    Fully connected critic network: two ReLU hidden layers followed by a
    linear output layer, applied to the concatenation of state and action.

    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Constructor.

        Args:
            input_shape (tuple): shape of the network input; only the last
                entry is read and used as the number of input features. It
                has to match the size of the concatenated state-action
                vector built in ``forward``;
            output_shape (tuple): shape of the network output; only the
                first entry is read and used as the number of output units;
            n_features (int): number of units of each of the two hidden
                layers;
            **kwargs: accepted for interface compatibility and ignored.

        """
        super().__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        # Xavier-uniform weights, with the gain matched to the activation
        # that follows each layer (ReLU for hidden layers, none for output).
        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('relu'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, state, action):
        """
        Compute the critic output for a batch of state-action pairs.

        Args:
            state (torch.Tensor): batch of states; cast to float and
                concatenated with ``action`` along dimension 1;
            action (torch.Tensor): batch of actions; cast to float.

        Returns:
            The output of the last layer with its trailing axis removed
            when that axis has size one (i.e. a single output unit). The
            batch axis is always preserved.

        """
        state_action = torch.cat((state.float(), action.float()), dim=1)

        features1 = F.relu(self._h1(state_action))
        features2 = F.relu(self._h2(features1))
        q = self._h3(features2)

        # Squeeze only the output axis. A bare ``torch.squeeze(q)`` would
        # also remove the batch axis when the batch holds a single sample,
        # making the output rank depend on the batch size.
        return torch.squeeze(q, -1)
class SACActorNetwork(nn.Module):
    """
    Fully connected actor network: two ReLU hidden layers followed by a
    linear output layer, applied to the state only.

    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        """
        Constructor.

        Args:
            input_shape (tuple): shape of the network input; only the last
                entry is read and used as the number of input features;
            output_shape (tuple): shape of the network output; only the
                first entry is read and used as the number of output units;
            n_features (int): number of units of each of the two hidden
                layers;
            **kwargs: accepted for interface compatibility and ignored.

        """
        super().__init__()

        in_dim, out_dim = input_shape[-1], output_shape[0]

        # All layers are created first and re-initialised afterwards, so the
        # random number generator is consumed in a fixed order.
        self._h1 = nn.Linear(in_dim, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, out_dim)

        # Xavier-uniform weights, gain chosen from the activation that
        # follows each layer.
        layer_activations = (
            (self._h1, 'relu'),
            (self._h2, 'relu'),
            (self._h3, 'linear'),
        )
        for layer, activation in layer_activations:
            nn.init.xavier_uniform_(
                layer.weight, gain=nn.init.calculate_gain(activation)
            )

    def forward(self, state):
        """
        Compute the actor output for a batch of states.

        Args:
            state (torch.Tensor): batch of states; dimension 1 is removed
                when it has size one, then the tensor is cast to float.

        Returns:
            The output of the last (linear) layer.

        """
        x = state.squeeze(1).float()

        hidden = F.relu(self._h1(x))
        hidden = F.relu(self._h2(hidden))

        return self._h3(hidden)