From 7eec728142cc02305032a7e2f624f9d398b21e47 Mon Sep 17 00:00:00 2001 From: jbion Date: Tue, 20 Dec 2016 15:24:25 +0100 Subject: Add interceptor for invalid subscriptions --- .../config/TopicSubscriptionInterceptor.java | 38 +++++++++++ .../sevenwonders/config/WebSocketConfig.java | 15 +++- .../java/org/luxons/sevenwonders/game/Decks.java | 14 ++-- .../java/org/luxons/sevenwonders/game/Game.java | 4 ++ .../java/org/luxons/sevenwonders/game/Lobby.java | 4 ++ .../sevenwonders/repositories/LobbyRepository.java | 13 +++- .../validation/DestinationAccessValidator.java | 79 ++++++++++++++++++++++ 7 files changed, 159 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java create mode 100644 src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java (limited to 'src/main/java/org/luxons/sevenwonders') diff --git a/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java b/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java new file mode 100644 index 00000000..f8d92068 --- /dev/null +++ b/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java @@ -0,0 +1,38 @@ +package org.luxons.sevenwonders.config; + +import java.security.Principal; + +import org.luxons.sevenwonders.validation.DestinationAccessValidator; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptorAdapter; +import org.springframework.stereotype.Component; + +@Component +public class TopicSubscriptionInterceptor extends ChannelInterceptorAdapter { + + private final DestinationAccessValidator destinationAccessValidator; + + @Autowired + public TopicSubscriptionInterceptor(DestinationAccessValidator destinationAccessValidator) { + this.destinationAccessValidator = destinationAccessValidator; + } + + @Override + public Message preSend(Message message, MessageChannel channel) { + StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); + if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) { + Principal userPrincipal = headerAccessor.getUser(); + if (!destinationAccessValidator.hasAccess(userPrincipal.getName(), headerAccessor.getDestination())) { + throw new ForbiddenSubscriptionException(); + } + } + return message; + } + + private static class ForbiddenSubscriptionException extends RuntimeException { + } +} diff --git a/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java b/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java index 3b588894..d54d8da4 100644 --- a/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java +++ b/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java @@ -1,7 +1,9 @@ package org.luxons.sevenwonders.config; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; @@ -12,13 +14,20 @@ import org.springframework.web.socket.server.support.DefaultHandshakeHandler; @EnableWebSocketMessageBroker public class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer { + private final TopicSubscriptionInterceptor topicSubscriptionInterceptor; + + @Autowired + public WebSocketConfig(TopicSubscriptionInterceptor topicSubscriptionInterceptor) { + this.topicSubscriptionInterceptor = topicSubscriptionInterceptor; + } + @Override public void configureMessageBroker(MessageBrokerRegistry config) { // prefixes for all subscriptions config.enableSimpleBroker("/queue", "/topic"); config.setUserDestinationPrefix("/user"); - // prefix for all calls from clients + // /app for normal calls, /topic for subscription events config.setApplicationDestinationPrefixes("/app", "/topic"); } @@ -35,4 +44,8 @@ public class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer { return new AnonymousUsersHandshakeHandler(); } + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.setInterceptors(topicSubscriptionInterceptor); + } } \ No newline at end of file diff --git a/src/main/java/org/luxons/sevenwonders/game/Decks.java b/src/main/java/org/luxons/sevenwonders/game/Decks.java index 12fda17f..abc8e817 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Decks.java +++ b/src/main/java/org/luxons/sevenwonders/game/Decks.java @@ -9,8 +9,6 @@ import org.luxons.sevenwonders.game.cards.Card; public class Decks { - private static final int HAND_SIZE = 7; - private Map> cardsPerAge = new HashMap<>(); public Decks(Map> cardsPerAge) { @@ -23,7 +21,7 @@ public class Decks { .flatMap(List::stream) .filter(c -> c.getName().equals(cardName)) .findAny() - .orElseThrow(CardNotFoundException::new); + .orElseThrow(() -> new CardNotFoundException(cardName)); } Map> deal(int age, int nbPlayers) { @@ -41,20 +39,24 @@ public class Decks { } private void validateNbCards(List deck, int nbPlayers) { - if (nbPlayers * HAND_SIZE != deck.size()) { + if (deck.size() % nbPlayers != 0) { throw new IllegalArgumentException( - String.format("%d cards is not the expected number for %d players", deck.size(), nbPlayers)); + String.format("Cannot deal %d cards evenly between %d players", deck.size(), nbPlayers)); } } private Map> deal(List deck, int nbPlayers) { Map> hands = new HashMap<>(nbPlayers); for (int i = 0; i < deck.size(); i++) { - hands.putIfAbsent(i % nbPlayers, new ArrayList<>()).add(deck.get(i)); + hands.putIfAbsent(i % nbPlayers, new ArrayList<>()); + hands.get(i % nbPlayers).add(deck.get(i)); } return hands; } public class CardNotFoundException extends RuntimeException { + CardNotFoundException(String message) { + super(message); + } } } diff --git a/src/main/java/org/luxons/sevenwonders/game/Game.java b/src/main/java/org/luxons/sevenwonders/game/Game.java index 70a5b615..53b8bc53 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Game.java +++ b/src/main/java/org/luxons/sevenwonders/game/Game.java @@ -48,6 +48,10 @@ public class Game { return table.getPlayers(); } + public boolean containsUser(String userName) { + return getPlayers().stream().anyMatch(p -> p.getUserName().equals(userName)); + } + private void startNewAge() { currentAge++; hands = decks.deal(currentAge, table.getNbPlayers()); diff --git a/src/main/java/org/luxons/sevenwonders/game/Lobby.java b/src/main/java/org/luxons/sevenwonders/game/Lobby.java index 241c5530..35f72f0f 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Lobby.java +++ b/src/main/java/org/luxons/sevenwonders/game/Lobby.java @@ -86,6 +86,10 @@ public class Lobby { return owner.getUserName().equals(userName); } + public boolean containsUser(String userName) { + return players.stream().anyMatch(p -> p.getUserName().equals(userName)); + } + private static class GameAlreadyStartedException extends IllegalStateException { } diff --git a/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java b/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java index bede34af..21348890 100644 --- a/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java +++ b/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java @@ -18,6 +18,8 @@ public class LobbyRepository { private Map lobbies = new HashMap<>(); + private Map lobbiesById = new HashMap<>(); + private long lastGameId = 0; @Autowired @@ -36,6 +38,7 @@ public class LobbyRepository { long id = lastGameId++; Lobby lobby = new Lobby(id, gameName, owner, gameDefinitionLoader.getGameDefinition()); lobbies.put(gameName, lobby); + lobbiesById.put(id, lobby); return lobby; } @@ -47,7 +50,15 @@ public class LobbyRepository { return lobby; } - private static class LobbyNotFoundException extends RuntimeException { + public Lobby find(long lobbyId) { + Lobby lobby = lobbiesById.get(lobbyId); + if (lobby == null) { + throw new LobbyNotFoundException(String.valueOf(lobbyId)); + } + return lobby; + } + + public static class LobbyNotFoundException extends RuntimeException { LobbyNotFoundException(String name) { super("Lobby not found for game '" + name + "'"); } diff --git a/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java b/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java new file mode 100644 index 00000000..bc7e52ce --- /dev/null +++ b/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java @@ -0,0 +1,79 @@ +package org.luxons.sevenwonders.validation; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.luxons.sevenwonders.game.Game; +import org.luxons.sevenwonders.game.Lobby; +import org.luxons.sevenwonders.repositories.GameRepository; +import org.luxons.sevenwonders.repositories.LobbyRepository; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +@Component +public class DestinationAccessValidator { + + private static final Pattern lobbyDestination = Pattern.compile(".*?/lobby/(?\\d+?)(/.*)?"); + + private static final Pattern gameDestination = Pattern.compile(".*?/game/(?\\d+?)(/.*)?"); + + private final LobbyRepository lobbyRepository; + + private final GameRepository gameRepository; + + @Autowired + public DestinationAccessValidator(LobbyRepository lobbyRepository, GameRepository gameRepository) { + this.lobbyRepository = lobbyRepository; + this.gameRepository = gameRepository; + } + + public boolean hasAccess(String userName, String destination) { + if (userName == null) { + // unnamed user cannot belong to anything + return false; + } + if (hasForbiddenGameReference(userName, destination)) { + return false; + } + if (hasForbiddenLobbyReference(userName, destination)) { + return false; + } + return true; + } + + private boolean hasForbiddenGameReference(String userName, String destination) { + Matcher gameMatcher = gameDestination.matcher(destination); + if (!gameMatcher.matches()) { + return false; // no game reference is always OK + } + int gameId = extractId(gameMatcher); + return !isUserInGame(userName, gameId); + } + + private boolean hasForbiddenLobbyReference(String userName, String destination) { + Matcher lobbyMatcher = lobbyDestination.matcher(destination); + if (!lobbyMatcher.matches()) { + return false; // no lobby reference is always OK + } + int lobbyId = extractId(lobbyMatcher); + return !isUserInLobby(userName, lobbyId); + } + + private boolean isUserInGame(String userName, int gameId) { + Game game = gameRepository.find(gameId); + return game.containsUser(userName); + } + + private boolean isUserInLobby(String userName, int lobbyId) { + Lobby lobby = lobbyRepository.find(lobbyId); + return lobby.containsUser(userName); + } + + private static int extractId(Matcher matcher) { + String id = matcher.group("id"); + if (id == null) { + throw new IllegalArgumentException("No id matched in the destination"); + } + return Integer.parseInt(id); + } +} -- cgit