# Source code for mushroom_rl_benchmark.builders.network.a2c_network

import torch
import torch.nn as nn


class A2CNetwork(nn.Module):
    """Three-layer MLP: two tanh hidden layers followed by a linear output layer.

    Args:
        input_shape: shape of the network input; only its last entry (the
            number of input features) is used.
        output_shape: shape of the network output; only its first entry (the
            number of outputs) is used.
        n_features: width of both hidden layers.
        **kwargs: accepted and ignored, so extra construction parameters can
            be passed through without error.
    """

    def __init__(self, input_shape, output_shape, n_features, **kwargs):
        super(A2CNetwork, self).__init__()

        n_input = input_shape[-1]
        n_output = output_shape[0]

        self._h1 = nn.Linear(n_input, n_features)
        self._h2 = nn.Linear(n_features, n_features)
        self._h3 = nn.Linear(n_features, n_output)

        # Xavier-uniform weights scaled by the gain of the activation applied
        # after each layer: tanh for the two hidden layers, identity for the
        # output layer. Biases are not touched here and keep PyTorch's
        # default nn.Linear initialization.
        nn.init.xavier_uniform_(self._h1.weight,
                                gain=nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self._h2.weight,
                                gain=nn.init.calculate_gain('tanh'))
        nn.init.xavier_uniform_(self._h3.weight,
                                gain=nn.init.calculate_gain('linear'))

    def forward(self, state, **kwargs):
        """Compute the network output for ``state``.

        Args:
            state: tensor whose last dimension has ``n_input`` entries, e.g.
                ``(batch, n_input)`` or ``(batch, 1, n_input)``. It is cast to
                ``float32`` before use. A singleton axis 1 is squeezed away
                only when the tensor has more than two dimensions.
            **kwargs: accepted and ignored.

        Returns:
            Tensor whose last dimension has ``n_output`` entries (linear
            output, no activation applied).
        """
        # Squeeze axis 1 only when it cannot be the feature axis. On a 2-D
        # (batch, 1) state, axis 1 *is* the feature axis of a single-feature
        # input. Removing it would feed a (batch,) vector to self._h1. A 1-D
        # state has no axis 1 at all, and torch.squeeze(state, 1) would raise.
        if state.dim() > 2:
            state = torch.squeeze(state, 1)

        features1 = torch.tanh(self._h1(state.float()))
        features2 = torch.tanh(self._h2(features1))
        return self._h3(features2)