Skip to content

Instantly share code, notes, and snippets.

@maxsap
Last active April 10, 2020 08:55
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 4 You must be signed in to fork a gist
  • Save maxsap/da154f99a2dbb414471a to your computer and use it in GitHub Desktop.
Save maxsap/da154f99a2dbb414471a to your computer and use it in GitHub Desktop.
Spring Security OAuth2 programmatic configuration.
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import org.apache.commons.lang3.SerializationUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.security.jwt.JwtHelper;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.OAuth2RefreshToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.DefaultAuthenticationKeyGenerator;
import org.springframework.security.oauth2.provider.token.TokenStore;
import org.springframework.transaction.annotation.Transactional;
import persistence.dao.IOAuthAccessToken;
import persistence.dao.IOAuthRefreshToken;
import persistence.model.JWTCommon;
import persistence.model.OAuthAccessToken;
import persistence.model.OAuthRefreshToken;
import persistence.model.User;
import utils.Serializer;
/*
 * The Serializer used in this class is a GSON serializer; it can be replaced with any
 * other JSON serializer implementation.
 */
/**
 * {@link TokenStore} that persists JWT access and refresh tokens through the
 * {@code IOAuthAccessToken} / {@code IOAuthRefreshToken} DAOs, keyed by the token's JTI claim.
 *
 * <p>Tokens and authentications are stored Java-serialized (commons-lang3
 * {@link SerializationUtils}). All lookups first resolve the token string to its original JTI via
 * {@link #extractJtiFromRefreshToken(String)}.
 */
@Transactional(value="restTransactionManager", rollbackFor = Exception.class)
public class CustomJwtTokenStore implements TokenStore {

    /** Length of a canonical textual UUID (32 hex digits plus 4 hyphens). */
    private static final int UUID_LENGTH = 36;

    @Autowired
    private IOAuthAccessToken accessTokenDao;

    @Autowired
    private IOAuthRefreshToken refreshTokenDao;

    private final AuthenticationKeyGenerator authenticationKeyGenerator = new DefaultAuthenticationKeyGenerator();

    /**
     * Reads the authentication stored for the given access token and verifies that it belongs to
     * the user recorded on the stored row.
     *
     * @param paramOAuth2AccessToken the access token presented by the client
     * @return the stored authentication, or {@code null} when no row exists for the token's JTI
     * @throws InvalidTokenException when the stored user is missing or its e-mail differs from the
     *         principal inside the stored authentication
     */
    @Override
    public OAuth2Authentication readAuthentication(OAuth2AccessToken paramOAuth2AccessToken) {
        OAuthAccessToken token = findStoredAccessToken(paramOAuth2AccessToken.getValue());
        if (token == null) {
            return null;
        }
        OAuth2Authentication token2 = SerializationUtils.deserialize(token.getAuthentication());
        User loggedUser = (User) token2.getPrincipal();
        // Make sure the token has not been hijacked in any way and belongs to the logged-in user.
        User storedUser = token.getUser();
        if (storedUser == null || !storedUser.getEmail().equals(loggedUser.getEmail())) {
            throw new InvalidTokenException("Invalid access token");
        }
        return token2;
    }

    /**
     * Reads the authentication stored for the given encoded access token.
     *
     * @param paramString the encoded access token value
     * @return the stored authentication, or {@code null} when no row exists for the token's JTI
     */
    @Override
    public OAuth2Authentication readAuthentication(String paramString) {
        OAuthAccessToken storedObject = findStoredAccessToken(paramString);
        if (storedObject == null) {
            return null;
        }
        return SerializationUtils.deserialize(storedObject.getAuthentication());
    }

    /**
     * Persists an access token under its JTI, linked to the JTI of its refresh token (if any).
     *
     * @param paramOAuth2AccessToken the JWT access token to store
     * @param paramOAuth2Authentication the authentication the token was issued for
     */
    @Override
    public void storeAccessToken(OAuth2AccessToken paramOAuth2AccessToken, OAuth2Authentication paramOAuth2Authentication) {
        JWTCommon accessJTI = Serializer.createFromJson(JWTCommon.class, JwtHelper.decode(paramOAuth2AccessToken.getValue()).getClaims());
        OAuth2RefreshToken refreshToken = paramOAuth2AccessToken.getRefreshToken();
        // Some grants (e.g. "implicit") issue no refresh token; store a null link instead of failing.
        String refreshJti = (refreshToken == null) ? null : extractJtiFromRefreshToken(refreshToken.getValue()).getJti();
        OAuthAccessToken accessToken = new OAuthAccessToken(paramOAuth2AccessToken, paramOAuth2Authentication,
                authenticationKeyGenerator.extractKey(paramOAuth2Authentication), accessJTI.getJti(), refreshJti);
        accessTokenDao.saveOrUpdate(accessToken);
    }

    /**
     * Resolves a token string to the claims object carrying its original JTI.
     *
     * <p>A refresh token's JTI is a UUID only while the token has never been refreshed. After a
     * refresh, the JTI holds the previous encoded token, which may itself nest another one. This
     * method follows that chain iteratively until it reaches a JTI no longer than a UUID.
     *
     * <p>Fallbacks:
     * <ul>
     *   <li>If {@code original} cannot be decoded, or carries no JTI, the raw string is used as the
     *       JTI.</li>
     *   <li>If a nested JTI is long but cannot be decoded (or has no JTI of its own), that nested
     *       value is used as the JTI and traversal stops.</li>
     * </ul>
     *
     * @param original the encoded JWT token (access or refresh)
     * @return claims whose {@code jti} is the resolved identifier; never {@code null}
     */
    private JWTCommon extractJtiFromRefreshToken(String original) {
        JWTCommon result = decodeClaims(original);
        if (result == null || result.getJti() == null) {
            return opaqueJti(original);
        }
        // A JTI longer than a UUID is an encoded token from an earlier refresh: unwrap it.
        while (result.getJti().length() > UUID_LENGTH) {
            JWTCommon nested = decodeClaims(result.getJti());
            if (nested == null || nested.getJti() == null) {
                // Not a decodable JWT with its own JTI: stop rather than spin on the same value.
                return opaqueJti(result.getJti());
            }
            result = nested;
        }
        return result;
    }

    /**
     * Decodes a JWT and maps its claims onto {@link JWTCommon}.
     *
     * @return the claims, or {@code null} when the value is not a decodable JWT (best-effort by
     *         design: opaque values are treated as plain identifiers by the caller)
     */
    private static JWTCommon decodeClaims(String token) {
        try {
            return Serializer.createFromJson(JWTCommon.class, JwtHelper.decode(token).getClaims());
        } catch (Exception ignored) {
            return null;
        }
    }

    /** Wraps a raw value as a claims object whose JTI is that value. */
    private static JWTCommon opaqueJti(String value) {
        JWTCommon result = new JWTCommon();
        result.setJti(value);
        return result;
    }

    /** Looks up the stored access-token row for an encoded token; {@code null} when absent. */
    private OAuthAccessToken findStoredAccessToken(String value) {
        return accessTokenDao.findByTokenId(extractJtiFromRefreshToken(value).getJti());
    }

    /** Looks up the stored refresh-token row for an encoded token; {@code null} when absent. */
    private OAuthRefreshToken findStoredRefreshToken(String value) {
        return refreshTokenDao.findByTokenId(extractJtiFromRefreshToken(value).getJti());
    }

    /**
     * Deserializes the access tokens held by the given rows.
     *
     * @param rows stored rows; {@code null} is treated as empty
     * @return a new mutable list of access tokens
     */
    private static Collection<OAuth2AccessToken> deserializeAll(List<OAuthAccessToken> rows) {
        if (rows == null) {
            return new ArrayList<>();
        }
        List<OAuth2AccessToken> tokens = new ArrayList<>(rows.size());
        for (OAuthAccessToken row : rows) {
            tokens.add((OAuth2AccessToken) SerializationUtils.deserialize(row.getoAuth2AccessToken()));
        }
        return tokens;
    }

    /** @return the stored access token for the encoded value, or {@code null} when absent */
    @Override
    public OAuth2AccessToken readAccessToken(String paramString) {
        OAuthAccessToken storedObject = findStoredAccessToken(paramString);
        if (storedObject == null) {
            return null;
        }
        return (OAuth2AccessToken) SerializationUtils.deserialize(storedObject.getToken());
    }

    /** Deletes the stored row for the given access token, if present. */
    @Override
    public void removeAccessToken(OAuth2AccessToken paramOAuth2AccessToken) {
        OAuthAccessToken storedObject = findStoredAccessToken(paramOAuth2AccessToken.getValue());
        if (storedObject != null) {
            accessTokenDao.delete(storedObject);
        }
    }

    /** Persists a refresh token under its resolved (original) JTI. */
    @Override
    public void storeRefreshToken(OAuth2RefreshToken paramOAuth2RefreshToken, OAuth2Authentication paramOAuth2Authentication) {
        JWTCommon common = extractJtiFromRefreshToken(paramOAuth2RefreshToken.getValue());
        refreshTokenDao.saveOrUpdate(new OAuthRefreshToken(paramOAuth2RefreshToken, paramOAuth2Authentication, common.getJti()));
    }

    /** @return the stored refresh token for the encoded value, or {@code null} when absent */
    @Override
    public OAuth2RefreshToken readRefreshToken(String paramString) {
        OAuthRefreshToken refreshEntity = findStoredRefreshToken(paramString);
        if (refreshEntity == null) {
            return null;
        }
        return SerializationUtils.deserialize(refreshEntity.getoAuth2RefreshToken());
    }

    /** @return the authentication stored with the refresh token, or {@code null} when absent */
    @Override
    public OAuth2Authentication readAuthenticationForRefreshToken(OAuth2RefreshToken paramOAuth2RefreshToken) {
        OAuthRefreshToken storedObject = findStoredRefreshToken(paramOAuth2RefreshToken.getValue());
        if (storedObject == null) {
            return null;
        }
        return SerializationUtils.deserialize(storedObject.getAuthentication());
    }

    /** Deletes the stored row for the given refresh token, if present. */
    @Override
    public void removeRefreshToken(OAuth2RefreshToken paramOAuth2RefreshToken) {
        OAuthRefreshToken storedObject = findStoredRefreshToken(paramOAuth2RefreshToken.getValue());
        if (storedObject != null) {
            refreshTokenDao.delete(storedObject);
        }
    }

    /** Deletes the access-token row linked to the given refresh token's JTI, if present. */
    @Override
    public void removeAccessTokenUsingRefreshToken(OAuth2RefreshToken paramOAuth2RefreshToken) {
        JWTCommon common = extractJtiFromRefreshToken(paramOAuth2RefreshToken.getValue());
        OAuthAccessToken storedToken = accessTokenDao.findByRefreshToken(common.getJti());
        if (storedToken != null) {
            accessTokenDao.delete(storedToken);
        }
    }

    /** @return the stored access token for this authentication's key, or {@code null} when absent */
    @Override
    public OAuth2AccessToken getAccessToken(OAuth2Authentication paramOAuth2Authentication) {
        OAuthAccessToken storedObject = accessTokenDao.findByAuthenticationId(authenticationKeyGenerator.extractKey(paramOAuth2Authentication));
        if (storedObject == null) {
            return null;
        }
        return (OAuth2AccessToken) SerializationUtils.deserialize(storedObject.getToken());
    }

    /** @return all stored access tokens for the given client id and user name */
    @Override
    public Collection<OAuth2AccessToken> findTokensByClientIdAndUserName(String paramString1, String paramString2) {
        return deserializeAll(accessTokenDao.findByClientIdAndUserName(paramString1, paramString2));
    }

    /** @return all stored access tokens for the given client id */
    @Override
    public Collection<OAuth2AccessToken> findTokensByClientId(String paramString) {
        return deserializeAll(accessTokenDao.findByClientId(paramString));
    }
}
/**
 * Subset of JWT claims, populated from the decoded token's JSON claims by the GSON-based
 * {@code Serializer}.
 *
 * <p>The fields are snake_case because GSON maps JSON members to fields by name by default, and
 * these names must match the claim names ({@code exp}, {@code jti}, {@code client_id},
 * {@code authorities}, {@code user_name}).
 */
public class JWTCommon {
    /** "exp" claim: expiry as a JWT NumericDate (seconds since the epoch, per RFC 7519). */
    private Long exp;
    /** "jti" claim: token identifier; for refreshed tokens this may hold a nested encoded token. */
    private String jti;
    /** "client_id" claim. */
    private String client_id;
    /** "authorities" claim; NOTE(review): assumes each entry maps onto UserRole — confirm. */
    private List<UserRole> authorities;
    /** "user_name" claim. */
    private String user_name;

    public Long getExp() {
        return exp;
    }

    public void setExp(Long exp) {
        this.exp = exp;
    }

    public String getJti() {
        return jti;
    }

    public void setJti(String jti) {
        this.jti = jti;
    }

    public String getClientId() {
        return client_id;
    }

    public void setClientId(String clientId) {
        this.client_id = clientId;
    }

    public List<UserRole> getAuthorities() {
        return authorities;
    }

    public void setAuthorities(List<UserRole> authorities) {
        this.authorities = authorities;
    }

    public String getUserName() {
        return user_name;
    }

    public void setUserName(String userName) {
        this.user_name = userName;
    }
}
import javax.sql.DataSource;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.ComponentScan;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.security.authentication.AuthenticationManager;
import org.springframework.security.oauth2.config.annotation.configurers.ClientDetailsServiceConfigurer;
import org.springframework.security.oauth2.config.annotation.web.configuration.AuthorizationServerConfigurerAdapter;
import org.springframework.security.oauth2.config.annotation.web.configuration.EnableAuthorizationServer;
import org.springframework.security.oauth2.config.annotation.web.configuration.EnableResourceServer;
import org.springframework.security.oauth2.config.annotation.web.configurers.AuthorizationServerEndpointsConfigurer;
import org.springframework.security.oauth2.config.annotation.web.configurers.AuthorizationServerSecurityConfigurer;
import org.springframework.security.oauth2.provider.token.store.JwtAccessTokenConverter;
import services.impl.CustomJwtTokenStore;
/**
 * Root OAuth2 configuration: enables the resource server and imports the web security setup.
 * The nested {@link OAuth2Config} configures the authorization server.
 */
@Configuration
@ComponentScan
@EnableResourceServer
@Import({SecurityConfig.class})
public class OAuth2ServerConfig {

    /**
     * Authorization server setup: JWT token conversion, the database-backed token store, endpoint
     * access rules and a single in-memory client.
     */
    @Configuration
    @EnableAuthorizationServer
    protected static class OAuth2Config extends AuthorizationServerConfigurerAdapter {

        private static final String TRUSTED_CLIENT_AUTHORITY_RULE = "hasAuthority('ROLE_TRUSTED_CLIENT')";
        private static final String TOKEN_KEY_ACCESS_RULE = "isAnonymous() || hasAuthority('ROLE_TRUSTED_CLIENT')";

        private static final String CLIENT_ID = "my-trusted-client";
        // NOTE(review): plaintext client secret in source; consider externalizing it.
        private static final String CLIENT_SECRET = "secret";
        private static final String RESOURCE_ID = "test";
        // NOTE(review): 11 s for both tokens looks like a demo value; a refresh token that
        // expires together with its access token cannot be used to refresh it — confirm intent.
        private static final int ACCESS_TOKEN_VALIDITY_SECONDS = 11;
        private static final int REFRESH_TOKEN_VALIDITY_SECONDS = 11;

        @Autowired
        private AuthenticationManager authenticationManager;

        // NOTE(review): injected but not used in this class.
        @Autowired
        private GlobalConfigurations globalConfigs;

        // NOTE(review): injected but not used in this class.
        @Autowired
        @Qualifier("restDataSource")
        private DataSource datasource;

        /** JWT converter used to encode and decode tokens. */
        @Bean
        public JwtAccessTokenConverter accessTokenConverter() {
            return new JwtAccessTokenConverter();
        }

        /** Token store backed by the application's DAOs. */
        @Bean
        public CustomJwtTokenStore tokenStore() {
            return new CustomJwtTokenStore();
        }

        /** Restricts the token-key and check-token endpoints. */
        @Override
        public void configure(AuthorizationServerSecurityConfigurer oauthServer) throws Exception {
            AuthorizationServerSecurityConfigurer withKeyRule = oauthServer.tokenKeyAccess(TOKEN_KEY_ACCESS_RULE);
            withKeyRule.checkTokenAccess(TRUSTED_CLIENT_AUTHORITY_RULE);
        }

        /** Wires the authentication manager, token store and converter into the endpoints. */
        @Override
        public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
            endpoints.authenticationManager(authenticationManager);
            endpoints.tokenStore(tokenStore());
            endpoints.accessTokenConverter(accessTokenConverter());
        }

        /** Registers the single trusted in-memory client. */
        @Override
        public void configure(ClientDetailsServiceConfigurer clients) throws Exception {
            clients.inMemory()
                    .withClient(CLIENT_ID)
                    .resourceIds(RESOURCE_ID)
                    .authorizedGrantTypes("password", "authorization_code", "refresh_token", "implicit")
                    .authorities("ROLE_CLIENT", "ROLE_TRUSTED_CLIENT")
                    .scopes("read", "write", "trust", "update")
                    .accessTokenValiditySeconds(ACCESS_TOKEN_VALIDITY_SECONDS)
                    .refreshTokenValiditySeconds(REFRESH_TOKEN_VALIDITY_SECONDS)
                    .secret(CLIENT_SECRET);
        }
    }
}
/**
 * Persistent row for an issued access token, keyed by the token's JTI.
 *
 * <p>The token, its OAuth2 request and the full authentication are each stored Java-serialized
 * (commons-lang3 {@code SerializationUtils}).
 */
@Entity
@Table(indexes = {
        // One index per lookup column set; tokenId is already covered by the primary key.
        @Index(name = "OAuthAccessTokenAuthenticationIdIndex", columnList = "authenticationId"),
        @Index(name = "OAuthAccessTokenClientUserIndex", columnList = "clientId, userName"),
        @Index(name = "OAuthAccessTokenRefreshTokenIndex", columnList = "refreshToken") })
@DynamicInsert
@DynamicUpdate
public class OAuthAccessToken implements Serializable {
    private static final long serialVersionUID = -7945135597484875770L;

    /** JTI of the access token; primary key. */
    @Id
    @Column(unique = true, nullable = false, insertable = false, updatable = false, length = 256)
    @Untouchable
    private String tokenId;

    /** Key produced by the authentication key generator, used to find a token by authentication. */
    @Column(nullable = false, length=256)
    private String authenticationId;

    /** {@code authentication.getName()} at issue time. */
    @Column(nullable = false, length=256)
    private String userName;

    /** Principal of the user authentication; {@code null} for client-only authentications. */
    @ManyToOne(fetch = FetchType.EAGER)
    private User user;

    /** Client id from the OAuth2 request. */
    @Column(nullable = false, length=256)
    private String clientId;

    /**
     * JTI of the linked refresh token, or {@code null} when the grant issued none.
     * NOTE(review): an existing schema created with NOT NULL must be migrated manually.
     */
    @Column(length=256)
    private String refreshToken;

    @Lob
    @Column(nullable = false)
    private byte[] oAuth2AccessToken;

    @Lob
    @Column(nullable = false)
    private byte[] oauth2Request;

    @Lob
    @Column(nullable = false)
    private byte[] authentication;

    /** @return the serialized access token bytes (same data as the {@code oAuth2AccessToken} field) */
    public byte[] getToken() {
        return oAuth2AccessToken;
    }

    /** No-arg constructor for the persistence provider. */
    public OAuthAccessToken() {
    }

    /**
     * Builds a row for a newly issued token.
     *
     * @param oAuth2AccessToken the token to store
     * @param authentication the authentication the token was issued for
     * @param authenticationId key derived from {@code authentication}
     * @param jti the access token's JTI (primary key)
     * @param refreshJTI JTI of the refresh token, or {@code null} if there is none
     */
    public OAuthAccessToken(final OAuth2AccessToken oAuth2AccessToken, final OAuth2Authentication authentication, final String authenticationId, String jti, String refreshJTI) {
        this.tokenId = jti;
        this.oAuth2AccessToken = SerializationUtils.serialize(oAuth2AccessToken);
        this.authenticationId = authenticationId;
        this.userName = authentication.getName();
        // Serialize the request exactly once so the column deserializes straight to an OAuth2Request.
        // NOTE(review): rows written by earlier versions hold double-serialized bytes here.
        this.oauth2Request = SerializationUtils.serialize(authentication.getOAuth2Request());
        this.clientId = authentication.getOAuth2Request().getClientId();
        this.authentication = SerializationUtils.serialize(authentication);
        this.refreshToken = refreshJTI;
        // Client-only authentications carry no user authentication.
        this.user = (authentication.getUserAuthentication() == null)
                ? null
                : (User) authentication.getUserAuthentication().getPrincipal();
    }

    // getter setters
}
/**
 * Persistent row for an issued refresh token, keyed by the JTI supplied at construction.
 * The token and its authentication are stored Java-serialized.
 */
@Entity
@Table(indexes = { @Index(name = "OAuthRefreshTokenIndexing", columnList = "tokenId") })
@DynamicInsert
@DynamicUpdate
public class OAuthRefreshToken implements Serializable {
    private static final long serialVersionUID = -917207147028399813L;

    /** JTI of the refresh token; primary key. */
    @Id
    @Column(unique = true, nullable = false, insertable = false, updatable = false, length = 256)
    @Untouchable
    private String tokenId;

    /** Serialized refresh token. */
    @Lob
    @Column(nullable = false)
    private byte[] oAuth2RefreshToken;

    /** Serialized authentication the refresh token was issued for. */
    @Lob
    @Column(nullable = false)
    private byte[] authentication;

    /** No-arg constructor required by Hibernate. */
    public OAuthRefreshToken() {
    }

    /**
     * Builds a row for a refresh token.
     *
     * @param oAuth2RefreshToken the refresh token; must be {@link Serializable} at runtime
     * @param authentication the authentication the token belongs to
     * @param jti the identifier used as primary key
     */
    public OAuthRefreshToken(OAuth2RefreshToken oAuth2RefreshToken, OAuth2Authentication authentication, String jti) {
        this.tokenId = jti;
        final Serializable refreshTokenPayload = (Serializable) oAuth2RefreshToken;
        this.oAuth2RefreshToken = SerializationUtils.serialize(refreshTokenPayload);
        final Serializable authenticationPayload = (Serializable) authentication;
        this.authentication = SerializationUtils.serialize(authenticationPayload);
    }

    //getters setters
}
<!-- THESE DEPENDENCIES ARE NEEDED -->
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-web</artifactId>
<version>4.0.3.RELEASE</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-tx</artifactId>
<version>4.0.3.RELEASE</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.ext</groupId>
<artifactId>jersey-spring3</artifactId>
<version>2.7</version>
</dependency>
<dependency>
<groupId>org.glassfish.jersey.ext</groupId>
<artifactId>jersey-bean-validation</artifactId>
<version>2.7</version>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-web</artifactId>
<version>3.2.3.RELEASE</version>
</dependency>
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-config</artifactId>
<version>3.2.3.RELEASE</version>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-context</artifactId>
<version>4.0.3.RELEASE</version>
</dependency>
<!-- NOTE(review): aspectjweaver is 1.7.0 while aspectjrt/aspectjtools are 1.6.12; confirm this version mix is intended. -->
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjweaver</artifactId>
<version>1.7.0</version>
</dependency>
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjrt</artifactId>
<version>1.6.12</version>
</dependency>
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjtools</artifactId>
<version>1.6.12</version>
</dependency>
<dependency>
<groupId>org.projectreactor</groupId>
<artifactId>reactor-spring</artifactId>
<version>1.0.1.RELEASE</version>
</dependency>
<!-- OAuth2 authorization/resource server support -->
<dependency>
<groupId>org.springframework.security.oauth</groupId>
<artifactId>spring-security-oauth2</artifactId>
<version>2.0.2.RELEASE</version>
</dependency>
<!-- Provides JwtHelper, used by the custom token store to decode JWT claims -->
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-jwt</artifactId>
<version>1.0.2.RELEASE</version>
</dependency>
@Configuration
@EnableWebSecurity
public class SecurityConfig extends WebSecurityConfigurerAdapter {
@Autowired
private AuthenticationEntryPoint authenticationEntryPoint;
@Autowired
private AccessDeniedHandler accessDeniedHandler;
@Override
public void configure(WebSecurity web) throws Exception {
web.ignoring().antMatchers("/webjars/**", "/images/**", "/oauth/uncache_approvals", "/oauth/cache_approvals", "ANY_OTHER_ENDPOINT_TO_EXCLUDE");
}
@Override
@Bean
public AuthenticationManager authenticationManagerBean() throws Exception {
return super.authenticationManagerBean();
}
@Override
protected void configure(AuthenticationManagerBuilder auth) throws Exception {
// THIS IS USING A CUSTOM USER DETAILS SERVICE
auth.userDetailsService(userService()).passwordEncoder(passwordEncoder());
}
@Override
protected void configure(HttpSecurity http) throws Exception {
http
.anonymous().disable()
.sessionManagement()
.sessionCreationPolicy(SessionCreationPolicy.STATELESS)
.and()
.exceptionHandling()
.accessDeniedHandler(accessDeniedHandler) // handle access denied in general (for example comming from @PreAuthorization
.authenticationEntryPoint(authenticationEntryPoint) // handle authentication exceptions for unauthorized calls.
.and()
.requestMatchers().antMatchers("ACTUAL API ENDPOINTS GO HERE")
.and()
.authorizeRequests()
.antMatchers(HttpMethod.POST, ""ACTUAL API ENDPOINTS GO HERE"").permitAll()
.and()
.csrf().disable();
}
@Bean
@Autowired
ApplicationListener<AbstractAuthenticationEvent> loggerBean() {
return new org.springframework.security.authentication.event.LoggerListener();
}
@Bean
@Autowired
AccessDeniedHandler accessDeniedHandler() {
return new AccessDeniedExceptionHandler();
}
@Bean
@Autowired
AuthenticationEntryPoint entryPointBean() {
return new UnauthorizedEntryPoint();
}
@Bean
public BCryptPasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
@Bean
public UserService userService() {
return new UserServiceImpl();
}
}
<?xml version="1.0" encoding="UTF-8"?>
<web-app xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xmlns="http://java.sun.com/xml/ns/javaee"
xsi:schemaLocation="http://java.sun.com/xml/ns/javaee http://java.sun.com/xml/ns/javaee/web-app_3_0.xsd"
version="3.0">
<display-name>TEST</display-name>
<!-- Root application context is annotation-based -->
<context-param>
<param-name>contextClass</param-name>
<param-value>
org.springframework.web.context.support.AnnotationConfigWebApplicationContext
</param-value>
</context-param>
<context-param>
<param-name>contextConfigLocation</param-name>
<param-value>path_to_config</param-value>
</context-param>
<listener>
<listener-class>org.springframework.web.context.ContextLoaderListener</listener-class>
</listener>
<listener>
<listener-class>org.springframework.web.context.request.RequestContextListener</listener-class>
</listener>
<!-- Jersey Servlet-->
<servlet>
<servlet-name>SpringApplication</servlet-name>
<servlet-class>org.glassfish.jersey.servlet.ServletContainer</servlet-class>
<init-param>
<param-name>javax.ws.rs.Application</param-name>
<param-value>PATH_TO_JerseyConfig</param-value>
</init-param>
<load-on-startup>1</load-on-startup>
</servlet>
<!-- Spring MVC Servlet-->
<servlet>
<servlet-name>spring</servlet-name>
<servlet-class>org.springframework.web.servlet.DispatcherServlet</servlet-class>
<init-param>
<param-name>contextConfigLocation</param-name>
<param-value>/WEB-INF/spring/appservlet/servlet-context.xml</param-value>
</init-param>
<load-on-startup>2</load-on-startup>
</servlet>
<!-- Set up Spring security filter chain -->
<filter>
<filter-name>springSecurityFilterChain</filter-name>
<filter-class>org.springframework.web.filter.DelegatingFilterProxy</filter-class>
<init-param>
<param-name>contextAttribute</param-name>
<!-- Look up the filter chain bean in the context of the "spring" DispatcherServlet (not the root context) -->
<param-value>org.springframework.web.servlet.FrameworkServlet.CONTEXT.spring</param-value>
</init-param>
</filter>
<!-- Default request encoding to UTF-8 when the client sends none (set forceEncoding=true to override client encodings) -->
<filter>
<filter-name>CharacterEncodingFilter</filter-name>
<filter-class>org.springframework.web.filter.CharacterEncodingFilter</filter-class>
<init-param>
<param-name>encoding</param-name>
<param-value>UTF-8</param-value>
</init-param>
</filter>
<!-- Filters run in filter-mapping order: the encoding filter must come first, before any filter reads request parameters -->
<filter-mapping>
<filter-name>CharacterEncodingFilter</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
<filter-mapping>
<filter-name>springSecurityFilterChain</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
<!-- NOTE(review): no <filter> named corsFilter is declared in this file; confirm it is registered elsewhere (e.g. @WebFilter) -->
<filter-mapping>
<filter-name>corsFilter</filter-name>
<url-pattern>/*</url-pattern>
</filter-mapping>
<servlet-mapping>
<servlet-name>SpringApplication</servlet-name>
<url-pattern>/SOME_API_URL/*</url-pattern>
</servlet-mapping>
<!-- THIS MUST ALWAYS BE BOUND TO PARENT CONTEXT -->
<servlet-mapping>
<servlet-name>spring</servlet-name>
<url-pattern>/</url-pattern>
</servlet-mapping>
</web-app>
@shlomicthailand
Copy link

Very nice, thanks. I have a comment regarding "THIS MUST ALWAYS BE BOUND TO PARENT CONTEXT":
I struggled with this for a few days, and I could not get the Spring dispatcher to handle all requests to the root.
So I created a "virtual" prefix, which is also set in the oauth/token and oauth/authorize endpoint definitions,
and I needed to create a filter that inherits from ClientCredentialsTokenEndpointFilter and is initialized with the virtual prefix.
This is for everyone whose application is not entirely built on Spring MVC.

@mailuser41
Copy link

Could you please help me convert the above web.xml into a Spring Boot configuration? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment