import torch
import torch.nn as nn
import torch.nn.functional as F
[docs]
class TRPONetwork(nn.Module):
[docs]
def __init__(self, input_shape, output_shape, n_features, **kwargs):
super(TRPONetwork, self).__init__()
n_input = input_shape[-1]
n_output = output_shape[0]
self._h1 = nn.Linear(n_input, n_features)
self._h2 = nn.Linear(n_features, n_features)
self._h3 = nn.Linear(n_features, n_output)
nn.init.xavier_uniform_(self._h1.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h2.weight,
gain=nn.init.calculate_gain('relu'))
nn.init.xavier_uniform_(self._h3.weight,
gain=nn.init.calculate_gain('linear'))
[docs]
def forward(self, state, **kwargs):
features1 = F.relu(self._h1(torch.squeeze(state, 1).float()))
features2 = F.relu(self._h2(features1))
a = self._h3(features2)
return a