From 7eec728142cc02305032a7e2f624f9d398b21e47 Mon Sep 17 00:00:00 2001 From: jbion Date: Tue, 20 Dec 2016 15:24:25 +0100 Subject: Add interceptor for invalid subscriptions --- .../config/TopicSubscriptionInterceptor.java | 38 +++++++ .../sevenwonders/config/WebSocketConfig.java | 15 ++- .../java/org/luxons/sevenwonders/game/Decks.java | 14 ++- .../java/org/luxons/sevenwonders/game/Game.java | 4 + .../java/org/luxons/sevenwonders/game/Lobby.java | 4 + .../sevenwonders/repositories/LobbyRepository.java | 13 ++- .../validation/DestinationAccessValidator.java | 79 +++++++++++++ .../validation/DestinationAccessValidatorTest.java | 125 +++++++++++++++++++++ 8 files changed, 284 insertions(+), 8 deletions(-) create mode 100644 src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java create mode 100644 src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java create mode 100644 src/test/java/org/luxons/sevenwonders/validation/DestinationAccessValidatorTest.java (limited to 'src') diff --git a/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java b/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java new file mode 100644 index 00000000..f8d92068 --- /dev/null +++ b/src/main/java/org/luxons/sevenwonders/config/TopicSubscriptionInterceptor.java @@ -0,0 +1,38 @@ +package org.luxons.sevenwonders.config; + +import java.security.Principal; + +import org.luxons.sevenwonders.validation.DestinationAccessValidator; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.messaging.Message; +import org.springframework.messaging.MessageChannel; +import org.springframework.messaging.simp.stomp.StompCommand; +import org.springframework.messaging.simp.stomp.StompHeaderAccessor; +import org.springframework.messaging.support.ChannelInterceptorAdapter; +import org.springframework.stereotype.Component; + +@Component +public class TopicSubscriptionInterceptor extends ChannelInterceptorAdapter { + + private final DestinationAccessValidator destinationAccessValidator; + + @Autowired + public TopicSubscriptionInterceptor(DestinationAccessValidator destinationAccessValidator) { + this.destinationAccessValidator = destinationAccessValidator; + } + + @Override + public Message preSend(Message message, MessageChannel channel) { + StompHeaderAccessor headerAccessor = StompHeaderAccessor.wrap(message); + if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) { + Principal userPrincipal = headerAccessor.getUser(); + if (!destinationAccessValidator.hasAccess(userPrincipal.getName(), headerAccessor.getDestination())) { + throw new ForbiddenSubscriptionException(); + } + } + return message; + } + + private static class ForbiddenSubscriptionException extends RuntimeException { + } +} diff --git a/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java b/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java index 3b588894..d54d8da4 100644 --- a/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java +++ b/src/main/java/org/luxons/sevenwonders/config/WebSocketConfig.java @@ -1,7 +1,9 @@ package org.luxons.sevenwonders.config; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.messaging.simp.config.MessageBrokerRegistry; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker; @@ -12,13 +14,20 @@ import org.springframework.web.socket.server.support.DefaultHandshakeHandler; @EnableWebSocketMessageBroker public class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer { + private final TopicSubscriptionInterceptor topicSubscriptionInterceptor; + + @Autowired + public WebSocketConfig(TopicSubscriptionInterceptor topicSubscriptionInterceptor) { + this.topicSubscriptionInterceptor = topicSubscriptionInterceptor; + } + @Override public void configureMessageBroker(MessageBrokerRegistry config) { // prefixes for all subscriptions config.enableSimpleBroker("/queue", "/topic"); config.setUserDestinationPrefix("/user"); - // prefix for all calls from clients + // /app for normal calls, /topic for subscription events config.setApplicationDestinationPrefixes("/app", "/topic"); } @@ -35,4 +44,8 @@ public class WebSocketConfig extends AbstractWebSocketMessageBrokerConfigurer { return new AnonymousUsersHandshakeHandler(); } + @Override + public void configureClientInboundChannel(ChannelRegistration registration) { + registration.setInterceptors(topicSubscriptionInterceptor); + } } \ No newline at end of file diff --git a/src/main/java/org/luxons/sevenwonders/game/Decks.java b/src/main/java/org/luxons/sevenwonders/game/Decks.java index 12fda17f..abc8e817 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Decks.java +++ b/src/main/java/org/luxons/sevenwonders/game/Decks.java @@ -9,8 +9,6 @@ import org.luxons.sevenwonders.game.cards.Card; public class Decks { - private static final int HAND_SIZE = 7; - private Map> cardsPerAge = new HashMap<>(); public Decks(Map> cardsPerAge) { @@ -23,7 +21,7 @@ public class Decks { .flatMap(List::stream) .filter(c -> c.getName().equals(cardName)) .findAny() - .orElseThrow(CardNotFoundException::new); + .orElseThrow(() -> new CardNotFoundException(cardName)); } Map> deal(int age, int nbPlayers) { @@ -41,20 +39,24 @@ public class Decks { } private void validateNbCards(List deck, int nbPlayers) { - if (nbPlayers * HAND_SIZE != deck.size()) { + if (deck.size() % nbPlayers != 0) { throw new IllegalArgumentException( - String.format("%d cards is not the expected number for %d players", deck.size(), nbPlayers)); + String.format("Cannot deal %d cards evenly between %d players", deck.size(), nbPlayers)); } } private Map> deal(List deck, int nbPlayers) { Map> hands = new HashMap<>(nbPlayers); for (int i = 0; i < deck.size(); i++) { - hands.putIfAbsent(i % nbPlayers, new ArrayList<>()).add(deck.get(i)); + hands.putIfAbsent(i % nbPlayers, new ArrayList<>()); + hands.get(i % nbPlayers).add(deck.get(i)); } return hands; } public class CardNotFoundException extends RuntimeException { + CardNotFoundException(String message) { + super(message); + } } } diff --git a/src/main/java/org/luxons/sevenwonders/game/Game.java b/src/main/java/org/luxons/sevenwonders/game/Game.java index 70a5b615..53b8bc53 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Game.java +++ b/src/main/java/org/luxons/sevenwonders/game/Game.java @@ -48,6 +48,10 @@ public class Game { return table.getPlayers(); } + public boolean containsUser(String userName) { + return getPlayers().stream().anyMatch(p -> p.getUserName().equals(userName)); + } + private void startNewAge() { currentAge++; hands = decks.deal(currentAge, table.getNbPlayers()); diff --git a/src/main/java/org/luxons/sevenwonders/game/Lobby.java b/src/main/java/org/luxons/sevenwonders/game/Lobby.java index 241c5530..35f72f0f 100644 --- a/src/main/java/org/luxons/sevenwonders/game/Lobby.java +++ b/src/main/java/org/luxons/sevenwonders/game/Lobby.java @@ -86,6 +86,10 @@ public class Lobby { return owner.getUserName().equals(userName); } + public boolean containsUser(String userName) { + return players.stream().anyMatch(p -> p.getUserName().equals(userName)); + } + private static class GameAlreadyStartedException extends IllegalStateException { } diff --git a/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java b/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java index bede34af..21348890 100644 --- a/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java +++ b/src/main/java/org/luxons/sevenwonders/repositories/LobbyRepository.java @@ -18,6 +18,8 @@ public class LobbyRepository { private Map lobbies = new HashMap<>(); + private Map lobbiesById = new HashMap<>(); + private long lastGameId = 0; @Autowired @@ -36,6 +38,7 @@ public class LobbyRepository { long id = lastGameId++; Lobby lobby = new Lobby(id, gameName, owner, gameDefinitionLoader.getGameDefinition()); lobbies.put(gameName, lobby); + lobbiesById.put(id, lobby); return lobby; } @@ -47,7 +50,15 @@ public class LobbyRepository { return lobby; } - private static class LobbyNotFoundException extends RuntimeException { + public Lobby find(long lobbyId) { + Lobby lobby = lobbiesById.get(lobbyId); + if (lobby == null) { + throw new LobbyNotFoundException(String.valueOf(lobbyId)); + } + return lobby; + } + + public static class LobbyNotFoundException extends RuntimeException { LobbyNotFoundException(String name) { super("Lobby not found for game '" + name + "'"); } diff --git a/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java b/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java new file mode 100644 index 00000000..bc7e52ce --- /dev/null +++ b/src/main/java/org/luxons/sevenwonders/validation/DestinationAccessValidator.java @@ -0,0 +1,79 @@ +package org.luxons.sevenwonders.validation; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import org.luxons.sevenwonders.game.Game; +import org.luxons.sevenwonders.game.Lobby; +import org.luxons.sevenwonders.repositories.GameRepository; +import org.luxons.sevenwonders.repositories.LobbyRepository; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +@Component +public class DestinationAccessValidator { + + private static final Pattern lobbyDestination = Pattern.compile(".*?/lobby/(?\\d+?)(/.*)?"); + + private static final Pattern gameDestination = Pattern.compile(".*?/game/(?\\d+?)(/.*)?"); + + private final LobbyRepository lobbyRepository; + + private final GameRepository gameRepository; + + @Autowired + public DestinationAccessValidator(LobbyRepository lobbyRepository, GameRepository gameRepository) { + this.lobbyRepository = lobbyRepository; + this.gameRepository = gameRepository; + } + + public boolean hasAccess(String userName, String destination) { + if (userName == null) { + // unnamed user cannot belong to anything + return false; + } + if (hasForbiddenGameReference(userName, destination)) { + return false; + } + if (hasForbiddenLobbyReference(userName, destination)) { + return false; + } + return true; + } + + private boolean hasForbiddenGameReference(String userName, String destination) { + Matcher gameMatcher = gameDestination.matcher(destination); + if (!gameMatcher.matches()) { + return false; // no game reference is always OK + } + int gameId = extractId(gameMatcher); + return !isUserInGame(userName, gameId); + } + + private boolean hasForbiddenLobbyReference(String userName, String destination) { + Matcher lobbyMatcher = lobbyDestination.matcher(destination); + if (!lobbyMatcher.matches()) { + return false; // no lobby reference is always OK + } + int lobbyId = extractId(lobbyMatcher); + return !isUserInLobby(userName, lobbyId); + } + + private boolean isUserInGame(String userName, int gameId) { + Game game = gameRepository.find(gameId); + return game.containsUser(userName); + } + + private boolean isUserInLobby(String userName, int lobbyId) { + Lobby lobby = lobbyRepository.find(lobbyId); + return lobby.containsUser(userName); + } + + private static int extractId(Matcher matcher) { + String id = matcher.group("id"); + if (id == null) { + throw new IllegalArgumentException("No id matched in the destination"); + } + return Integer.parseInt(id); + } +} diff --git a/src/test/java/org/luxons/sevenwonders/validation/DestinationAccessValidatorTest.java b/src/test/java/org/luxons/sevenwonders/validation/DestinationAccessValidatorTest.java new file mode 100644 index 00000000..1ae0b3fc --- /dev/null +++ b/src/test/java/org/luxons/sevenwonders/validation/DestinationAccessValidatorTest.java @@ -0,0 +1,125 @@ +package org.luxons.sevenwonders.validation; + +import org.junit.Before; +import org.junit.Test; +import org.luxons.sevenwonders.game.Game; +import org.luxons.sevenwonders.game.Lobby; +import org.luxons.sevenwonders.game.Player; +import org.luxons.sevenwonders.game.Settings; +import org.luxons.sevenwonders.game.data.GameDefinitionLoader; +import org.luxons.sevenwonders.repositories.GameRepository; +import org.luxons.sevenwonders.repositories.LobbyRepository; +import org.luxons.sevenwonders.repositories.LobbyRepository.LobbyNotFoundException; + +import static org.junit.Assert.*; + +public class DestinationAccessValidatorTest { + + private LobbyRepository lobbyRepository; + + private GameRepository gameRepository; + + private DestinationAccessValidator destinationAccessValidator; + + @Before + public void setup() { + gameRepository = new GameRepository(); + lobbyRepository = new LobbyRepository(new GameDefinitionLoader()); + destinationAccessValidator = new DestinationAccessValidator(lobbyRepository, gameRepository); + } + + private Lobby createLobby(String gameName, String ownerUserName, String... otherPlayers) { + Player owner = new Player(ownerUserName, ownerUserName); + Lobby lobby = lobbyRepository.create(gameName, owner); + for (String playerName : otherPlayers) { + Player player = new Player(playerName, playerName); + lobby.addPlayer(player); + } + return lobby; + } + + private Game createGame(String gameName, String ownerUserName, String... otherPlayers) { + Lobby lobby = createLobby(gameName, ownerUserName, otherPlayers); + Game game = lobby.startGame(new Settings()); + gameRepository.add(game); + return game; + } + + @Test + public void validate_successWhenNoReference() { + assertTrue(destinationAccessValidator.hasAccess("", "")); + assertTrue(destinationAccessValidator.hasAccess("", "test")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "test")); + } + + @Test + public void validate_successWhenNoRefFollows() { + assertTrue(destinationAccessValidator.hasAccess("testUser", "/game/")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/lobby/")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "prefix/game/")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "prefix/lobby/")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/game//suffix")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/lobby//suffix")); + } + + @Test + public void validate_successWhenRefIsNotANumber() { + assertTrue(destinationAccessValidator.hasAccess("testUser", "/game/notANumber")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/lobby/notANumber")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "prefix/game/notANumber")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "prefix/lobby/notANumber")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/game/notANumber/suffix")); + assertTrue(destinationAccessValidator.hasAccess("testUser", "/lobby/notANumber/suffix")); + } + + @Test(expected = LobbyNotFoundException.class) + public void validate_failWhenNoLobbyExist() { + destinationAccessValidator.hasAccess("", "/lobby/0"); + } + + @Test(expected = LobbyNotFoundException.class) + public void validate_failWhenReferencedLobbyDoesNotExist() { + createLobby("Test Game", "ownerUser1"); + createLobby("Test Game 2", "ownerUser2"); + destinationAccessValidator.hasAccess("", "/lobby/3"); + } + + @Test + public void validate_failWhenUserIsNotPartOfReferencedLobby() { + createLobby("Test Game", "ownerUser"); + destinationAccessValidator.hasAccess("", "/lobby/0"); + } + + @Test + public void validate_successWhenUserIsOwnerOfReferencedLobby() { + createLobby("Test Game 1", "user1"); + assertTrue(destinationAccessValidator.hasAccess("user1", "/lobby/0")); + createLobby("Test Game 2", "user2"); + assertTrue(destinationAccessValidator.hasAccess("user2", "/lobby/1")); + } + + @Test + public void validate_successWhenUserIsMemberOfReferencedLobby() { + createLobby("Test Game 1", "user1", "user2"); + assertTrue(destinationAccessValidator.hasAccess("user2", "/lobby/0")); + createLobby("Test Game 2", "user3", "user4"); + assertTrue(destinationAccessValidator.hasAccess("user4", "/lobby/1")); + } + + @Test + public void validate_successWhenUserIsMemberOfReferencedGame() { + createGame("Test Game 1", "user1", "user2", "user3"); + assertTrue(destinationAccessValidator.hasAccess("user2", "/game/0")); + createGame("Test Game 2", "user4", "user5", "user6"); + assertTrue(destinationAccessValidator.hasAccess("user6", "/game/1")); + } + + @Test + public void validate_failsWhenUserPartOfReferencedGame() { +// lobbyRepository.create("Test Game Name"); +// gameRepository.add(); + assertTrue(destinationAccessValidator.hasAccess("", "/game/notAnId")); + assertTrue(destinationAccessValidator.hasAccess("", "/lobby/notAnId")); + } + +} \ No newline at end of file -- cgit