Secure Spring Boot REST APIs with Amazon Cognito

In this tutorial, you will learn how to secure Spring Boot REST APIs using Amazon Cognito.

Amazon Cognito is an identity and access management service for web and mobile applications: it handles user sign-up and sign-in and issues tokens that your APIs can use to control access.

In this example, we will:

  • Integrate a Spring Boot application with Amazon Cognito.
  • Create REST APIs for user registration, user login, and a protected API accessible only with a valid user access token.

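Before you start, create and configure a user pool in Amazon Cognito; for step-by-step instructions, see our guide on Amazon Cognito User Pool Setup. The code in this tutorial signs users in with the ADMIN_NO_SRP_AUTH flow and does not send a SECRET_HASH, so the user pool's app client must have the ALLOW_ADMIN_USER_PASSWORD_AUTH flow enabled and must be created without a client secret.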

Follow these steps to complete this tutorial:

  1. Visit the Spring Initializr website at https://start.spring.io.
  2. Create a Spring Boot application with the following details:
    • Project: Choose the project type (Maven or Gradle).
    • Language: Set the language to Java.
    • Spring Boot: Specify the Spring Boot version. The default selection is the latest stable version of Spring Boot, so you can leave it unchanged.
    • Project Metadata: Enter a Group and an Artifact for your project. The Group usually identifies your organization or project as a reverse domain name (for example, com.example), and the Artifact is the name of your project. The code in this tutorial uses the package com.example.app. Add any other metadata you need (description, package name, and so on).
    • Packaging: Choose JAR if you want a standalone executable JAR file, or WAR if you intend to deploy your application to an external servlet container or Java EE application server. A Spring Boot JAR includes an embedded web server (Tomcat by default), so you can run the JAR directly without deploying it to a separate Tomcat server.
    • Java: Select a Java version that your target deployment environment supports. The code in this tutorial uses jakarta.* imports, so it targets Spring Boot 3.x, which requires Java 17 or later.
  3. Add project dependencies:
    • Click the Add Dependencies button.
    • Choose the following dependencies: Spring Web, Spring Security, OAuth2 Resource Server, Lombok, and Spring Boot DevTools.

  4. Generate the project:
    • Click the Generate button.
    • Spring Initializr will generate a zip file containing your Spring Boot project.
  5. Download and extract the generated project:
    • Download the zip file generated by Spring Initializr.
    • Extract the contents of the zip file to a directory on your local machine.
  6. Import the project into your IDE:
    • Open your preferred IDE (IntelliJ IDEA, Eclipse, or Spring Tool Suite).
    • Import the extracted project as a Maven or Gradle project, depending on the build system you chose in Spring Initializr.
  7. Add Dependency:
  8. Add the Amazon Cognito Identity Provider module of the AWS SDK for Java 1.x (aws-java-sdk-cognitoidp) to the project.

    For Maven

    Add this to the pom.xml file:

    <dependency>
        <groupId>com.amazonaws</groupId>
        <artifactId>aws-java-sdk-cognitoidp</artifactId>
        <version>1.12.552</version>
    </dependency>

    For Gradle

    Add this to the build.gradle file:

    implementation group: 'com.amazonaws', name: 'aws-java-sdk-cognitoidp', version: '1.12.552'

  9. Add Configurations:
  10. Open the src/main/resources/application.properties file in your editor and add the following configuration:

    server.port=5000
    
    #aws credentials: never hardcode real keys here; these placeholders read environment variables
    aws.access-key=${AWS_ACCESS_KEY_ID}
    aws.access-secret=${AWS_SECRET_ACCESS_KEY}
    aws.default-region = us-east-1
    
    #cognito user pool
    aws.cognito.clientId=4amsm6k43n0b53frvpge0ucdd5
    aws.cognito.userPoolId=us-east-1_EGuu5pXd2
    aws.cognito.region=us-east-1
    aws.cognito.connectionTimeout=2000
    aws.cognito.readTimeout=2000
    
    #Replace us-east-1 and us-east-1_EGuu5pXd2 in both URLs with your region and userPoolId
    aws.cognito.jwk = https://cognito-idp.us-east-1.amazonaws.com/us-east-1_EGuu5pXd2/.well-known/jwks.json
    spring.security.oauth2.resourceserver.jwt.issuer-uri=https://cognito-idp.us-east-1.amazonaws.com/us-east-1_EGuu5pXd2
    
    logging.level.org.springframework=INFO
    logging.level.com.example=INFO
    
    # Logging pattern for console
    logging.pattern.console= %d{yyyy-MM-dd HH:mm:ss} - %msg%n

    Here's an explanation of the above configurations:

    server.port: The port on which the embedded server listens for incoming requests; here, 5000.

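    aws.access-key and aws.access-secret: The AWS access key ID and secret access key that the application uses to authenticate its calls to AWS. Treat them as secrets: never hardcode them or commit them to source control, supply them through environment variables or another secure source, and grant the IAM identity only the Cognito actions this application calls (cognito-idp:AdminCreateUser, cognito-idp:AdminSetUserPassword, cognito-idp:AdminInitiateAuth, and cognito-idp:AdminRespondToAuthChallenge).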

    aws.default-region: Specifies the default AWS region to be used by the application.

    aws.cognito.clientId: The ID of the app client you created in your Cognito user pool.

    aws.cognito.userPoolId: The ID of your Cognito user pool.

    aws.cognito.region: Specifies the AWS region where your Cognito user pool is located.

    aws.cognito.connectionTimeout and aws.cognito.readTimeout: These settings specify the connection and read timeouts for interacting with Cognito.

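    aws.cognito.jwk: The URL of the user pool's JSON Web Key Set (JWKS), which contains the public keys used to verify the signatures of Cognito-issued JWTs (JSON Web Tokens). This is a custom property that the code in this tutorial does not read. Because issuer-uri is set, Spring Security discovers the JWKS URL on its own. If you want to set it explicitly, Spring Boot's own property for this is spring.security.oauth2.resourceserver.jwt.jwk-set-uri.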

    spring.security.oauth2.resourceserver.jwt.issuer-uri: The issuer URI of your Cognito user pool. Spring Security uses it to discover the pool's signing keys and to check that the iss claim of each incoming JWT matches.

    logging.level.org.springframework and logging.level.com.example: These settings configure the log levels for specific packages or classes in the application.

    logging.pattern.console: Specifies the log message format for the console output.
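    The issuer URI and the JWKS URL both follow a fixed pattern built from the region and the user pool ID. A mismatch between these values is a common cause of token-validation failures. The following is a minimal sketch that assembles both URLs in plain Java. The class and method names are hypothetical, and the region-prefix check assumes the usual <region>_<id> format of Cognito user pool IDs.

```java
import java.util.Objects;

/** Hypothetical helper that derives Cognito endpoint URLs from a region and a user pool ID. */
final class CognitoUrls {

  private CognitoUrls() {}

  /** Issuer URI, as used by spring.security.oauth2.resourceserver.jwt.issuer-uri. */
  static String issuerUri(String region, String userPoolId) {
    Objects.requireNonNull(region, "region");
    Objects.requireNonNull(userPoolId, "userPoolId");
    // Assumption: user pool IDs are prefixed with their region, e.g. us-east-1_XXXXXXXXX
    if (!userPoolId.startsWith(region + "_")) {
      throw new IllegalArgumentException(
          "User pool ID " + userPoolId + " does not belong to region " + region);
    }
    return "https://cognito-idp." + region + ".amazonaws.com/" + userPoolId;
  }

  /** JWKS URL, as used by the custom aws.cognito.jwk property. */
  static String jwksUrl(String region, String userPoolId) {
    return issuerUri(region, userPoolId) + "/.well-known/jwks.json";
  }
}
```

    For example, CognitoUrls.jwksUrl("us-east-1", "us-east-1_EGuu5pXd2") yields the aws.cognito.jwk value shown in the configuration above.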


  11. Configure Security:
  12. Create a Spring configuration class that sets up and configures security for your Spring Boot application. It enables web security and method-level security using Spring Security:

    package com.example.app.security;
    
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import org.springframework.security.config.Customizer;
    import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity;
    import org.springframework.security.config.annotation.web.builders.HttpSecurity;
    import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
    import org.springframework.security.config.http.SessionCreationPolicy;
    import org.springframework.security.web.SecurityFilterChain;
    import jakarta.servlet.http.HttpServletResponse;
    
    @Configuration
    @EnableWebSecurity(debug = false)
    @EnableMethodSecurity(prePostEnabled = true)
    public class SecurityConfig {
    
    
      @Bean
      public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
        http.cors(cors -> cors.disable()).csrf(csrf -> csrf.disable())
            .authorizeHttpRequests((authz) -> authz.requestMatchers("/users/signup", "/users/signin")
                .permitAll().anyRequest().authenticated())
            .sessionManagement((sessionManagement) -> {
              sessionManagement.sessionCreationPolicy(SessionCreationPolicy.STATELESS);
            }).exceptionHandling(exceptionHandling -> exceptionHandling
                .authenticationEntryPoint((request, response, ex) -> {
                  response.sendError(HttpServletResponse.SC_UNAUTHORIZED, ex.getMessage());
                }))
            .oauth2ResourceServer((oauth2) -> oauth2.jwt(Customizer.withDefaults()));
    
        return http.build();
      }
    
    }

    This class configures Spring Security for the application. The filterChain(HttpSecurity http) method builds the security filter chain from the HttpSecurity object that Spring Security passes in. The configuration does the following:
    • Turns off Spring Security's CORS support and CSRF protection. CSRF protection can safely be disabled here because the API is stateless and authenticates each request with a bearer token rather than a session cookie. If a browser front end on another origin calls this API, configure CORS instead of disabling it.
    • Permits unauthenticated requests to /users/signup and /users/signin and requires authentication for every other request.
    • Sets the session creation policy to STATELESS, so the application never creates HTTP sessions. This is common for stateless REST APIs.
    • Uses a custom authentication entry point that responds with 401 Unauthorized when a request is not authenticated.
    • Enables OAuth2 Resource Server support with JWT authentication through oauth2ResourceServer((oauth2) -> oauth2.jwt(Customizer.withDefaults())). Customizer.withDefaults() applies Spring Security's default JWT settings. With these defaults, each bearer token is validated against the issuer configured in spring.security.oauth2.resourceserver.jwt.issuer-uri.
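    Spring Security performs the real validation of each bearer token: it checks the signature against the user pool's keys, the issuer, and the expiry. When a request is rejected with 401, it can still help to see what the token contains. The hypothetical helper below only base64url-decodes the payload segment of a compact JWT (header.payload.signature). It verifies nothing, so use it for debugging only, never for authorization decisions.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

/** Hypothetical debugging helper: decodes a JWT payload WITHOUT verifying the signature. */
final class JwtPeek {

  private JwtPeek() {}

  /** Returns the decoded JSON payload (the second dot-separated segment) of a compact JWT. */
  static String payloadJson(String token) {
    String[] parts = token.split("\\.");
    if (parts.length != 3) {
      throw new IllegalArgumentException(
          "Expected header.payload.signature, got " + parts.length + " part(s)");
    }
    // JWT segments are base64url-encoded without padding; the URL decoder accepts that
    byte[] json = Base64.getUrlDecoder().decode(parts[1]);
    return new String(json, StandardCharsets.UTF_8);
  }
}
```

    For a Cognito access token, the decoded payload should include claims such as iss, client_id, and token_use. For access tokens, token_use has the value access.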


  13. Configure AWS Cognito Client:
  14. Create a configuration class named CognitoConfig:

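    package com.example.app.config;
    
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.context.annotation.Bean;
    import org.springframework.context.annotation.Configuration;
    import com.amazonaws.ClientConfiguration;
    import com.amazonaws.auth.AWSStaticCredentialsProvider;
    import com.amazonaws.auth.BasicAWSCredentials;
    import com.amazonaws.services.cognitoidp.AWSCognitoIdentityProvider;
    import com.amazonaws.services.cognitoidp.AWSCognitoIdentityProviderClientBuilder;
    
    @Configuration
    public class CognitoConfig {
      @Value(value = "${aws.access-key}")
      private String accessKey;
    
      @Value(value = "${aws.access-secret}")
      private String secretKey;
    
      // Region of the user pool, read from application.properties instead of being hardcoded
      @Value(value = "${aws.cognito.region}")
      private String region;
    
      @Value(value = "${aws.cognito.connectionTimeout}")
      private int connectionTimeout;
    
      @Value(value = "${aws.cognito.readTimeout}")
      private int readTimeout;
    
      @Bean
      public AWSCognitoIdentityProvider cognitoClient() {
    
        BasicAWSCredentials awsCreds = new BasicAWSCredentials(accessKey, secretKey);
    
        // Connection and socket (read) timeouts, in milliseconds
        ClientConfiguration clientConfig = new ClientConfiguration()
            .withConnectionTimeout(connectionTimeout)
            .withSocketTimeout(readTimeout);
    
        return AWSCognitoIdentityProviderClientBuilder.standard()
            .withCredentials(new AWSStaticCredentialsProvider(awsCreds))
            .withClientConfiguration(clientConfig)
            .withRegion(region)
            .build();
      }
    }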

    This class is responsible for configuring and creating an instance of the AWS Cognito Identity Provider client to interact with the Amazon Cognito service. It sets up a Spring bean for the Amazon Cognito Identity Provider client, ensuring that it has the necessary AWS credentials and region configuration. This bean can then be used throughout the application to interact with the Amazon Cognito service for user authentication and management.
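    This bean passes static keys to the client. To keep those keys out of application.properties, you can point the properties at the standard AWS environment variables, for example aws.access-key=${AWS_ACCESS_KEY_ID} and aws.access-secret=${AWS_SECRET_ACCESS_KEY}, and Spring resolves them at startup. As an alternative, the SDK's DefaultAWSCredentialsProviderChain reads the same variables, as well as other sources, without custom code. The check below is a hypothetical sketch that fails fast with a clear message when either variable is missing or blank. It only inspects a map, so you would pass it System.getenv().

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/** Hypothetical fail-fast check for the standard AWS credential environment variables. */
final class AwsEnvCheck {

  static final List<String> REQUIRED = List.of("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY");

  private AwsEnvCheck() {}

  /** Returns the names of required variables that are missing or blank in the given map. */
  static List<String> missing(Map<String, String> env) {
    List<String> result = new ArrayList<>();
    for (String name : REQUIRED) {
      String value = env.get(name);
      if (value == null || value.isBlank()) {
        result.add(name);
      }
    }
    return result;
  }

  /** Throws at startup instead of failing later on the first Cognito call. */
  static void requireCredentials(Map<String, String> env) {
    List<String> missing = missing(env);
    if (!missing.isEmpty()) {
      throw new IllegalStateException("Missing AWS credential environment variables: " + missing);
    }
  }
}
```

    You could call AwsEnvCheck.requireCredentials(System.getenv()) at the start of main, before SpringApplication.run.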


  15. Create Data Transfer Objects:
  16. Create a class named SignUpRequestDto as a Data Transfer Object (DTO) to represent the request payload for a user's sign-up:

    package com.example.app.user.dto;
    
    import com.fasterxml.jackson.annotation.JsonInclude;
    import com.fasterxml.jackson.annotation.JsonInclude.Include;
    import lombok.AllArgsConstructor;
    import lombok.Builder;
    import lombok.Data;
    import lombok.NoArgsConstructor;
    
    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    @JsonInclude(Include.NON_NULL)
    public class SignUpRequestDto {
    
      private String email;
      private String password;
      
    }

    Create a class named SignUpResponseDto as a Data Transfer Object (DTO) to represent the response data for a user's sign-up:

    package com.example.app.user.dto;
    
    import com.fasterxml.jackson.annotation.JsonInclude;
    import com.fasterxml.jackson.annotation.JsonInclude.Include;
    import lombok.AllArgsConstructor;
    import lombok.Builder;
    import lombok.Data;
    import lombok.NoArgsConstructor;
    
    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    @JsonInclude(Include.NON_NULL)
    public class SignUpResponseDto {
      
      private int statusCode;
      private String statusMessage;
      
    }

    Create a class named SignInRequestDto as a Data Transfer Object (DTO) to represent the request payload for a user's sign-in:

    package com.example.app.user.dto;
    
    import com.fasterxml.jackson.annotation.JsonInclude;
    import com.fasterxml.jackson.annotation.JsonInclude.Include;
    import lombok.AllArgsConstructor;
    import lombok.Builder;
    import lombok.Data;
    import lombok.NoArgsConstructor;
    
    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    @JsonInclude(Include.NON_NULL)
    public class SignInRequestDto {
    
      private String email;
      private String password;
      private String newPassword;
    
    }

    Create a class named SignInResponseDto as a Data Transfer Object (DTO) to represent the response data for a user's sign-in:

    package com.example.app.user.dto;
    
    import com.fasterxml.jackson.annotation.JsonInclude;
    import com.fasterxml.jackson.annotation.JsonInclude.Include;
    import lombok.AllArgsConstructor;
    import lombok.Builder;
    import lombok.Data;
    import lombok.NoArgsConstructor;
    
    @Data
    @Builder
    @AllArgsConstructor
    @NoArgsConstructor
    @JsonInclude(Include.NON_NULL)
    public class SignInResponseDto {
    
      private String accessToken;
      private String refreshToken;
      private String idToken;
      private String tokenType;
      private String scope;
      private Integer expiresIn;
      
    }
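    Lombok's @Data generates a toString() that prints every field. As a result, logging a SignUpRequestDto or SignInRequestDto would write the password in clear text. With Lombok, you can annotate those fields with @ToString.Exclude. The sketch below shows the same idea without Lombok, as a Java record (Java 16 or later). The record name is hypothetical. Jackson 2.12 and later can bind JSON request bodies to records.

```java
/**
 * Hypothetical Lombok-free counterpart of SignInRequestDto as a Java record.
 * toString() is overridden so that passwords never end up in logs.
 */
record SignInRequest(String email, String password, String newPassword) {

  @Override
  public String toString() {
    return "SignInRequest[email=" + email
        + ", password=" + mask(password)
        + ", newPassword=" + mask(newPassword) + "]";
  }

  // Reveal only whether a secret was supplied, never its value
  private static String mask(String secret) {
    return secret == null ? "null" : "****";
  }
}
```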

  17. Create Custom Exceptions:
  18. Create classes to handle custom exceptions. Custom exceptions allow you to create specific exception types for your application that can be thrown when certain exceptional situations occur.

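    Let's start by creating a Java class named Error with three private fields: message, status, and timestamp. This class is a simple data container for the error details returned to the client. Because the name matches java.lang.Error, any file that imports this class uses it in place of java.lang.Error.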

    package com.example.app.exception.model;
    
    import lombok.Data;
    
    @Data
    public class Error {
      private String message;
      private int status;
      private Long timestamp;
    }

    Create a custom exception class named ValidationException, which extends the RuntimeException class:

    package com.example.app.exception;
    
    public class ValidationException extends RuntimeException {
      private static final long serialVersionUID = 1L;
    
      public ValidationException(String message) {
        super(message);
      }
    
    }

    Create a global exception handler class named GlobalExceptionHandlerController. This class handles specific exceptions in one place and returns consistent, customized error responses to clients when those exceptions occur during the application's execution:

    package com.example.app.exception.controller;
    
    import org.springframework.http.HttpStatus;
    import org.springframework.http.ResponseEntity;
    import org.springframework.web.bind.annotation.ControllerAdvice;
    import org.springframework.web.bind.annotation.ExceptionHandler;
    import com.example.app.exception.ValidationException;
    import com.example.app.exception.model.Error;
    import jakarta.servlet.http.HttpServletRequest;
    
    @ControllerAdvice
    public class GlobalExceptionHandlerController {
    
      // Translate ValidationException into 400 Bad Request with an Error body
      @ExceptionHandler(ValidationException.class)
      public ResponseEntity<Error> validation(ValidationException ex, HttpServletRequest request) {
        Error error = new Error();
        error.setMessage(ex.getMessage());
        error.setTimestamp(System.currentTimeMillis());
        error.setStatus(HttpStatus.BAD_REQUEST.value());
        return ResponseEntity.status(HttpStatus.BAD_REQUEST).body(error);
      }
    
    }
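    Because GlobalExceptionHandlerController turns every ValidationException into a 400 Bad Request, the service can reject obviously bad input before it calls Cognito. The validator below is a hypothetical sketch. It declares a local copy of ValidationException only so that it compiles on its own; in the application, you would throw com.example.app.exception.ValidationException instead. Cognito still enforces the user pool's own password policy.

```java
/** Local stand-in for the tutorial's ValidationException so this sketch compiles on its own. */
class ValidationException extends RuntimeException {
  private static final long serialVersionUID = 1L;

  ValidationException(String message) {
    super(message);
  }
}

/** Hypothetical pre-checks for a sign-up request, run before any Cognito call. */
final class SignUpValidator {

  private SignUpValidator() {}

  static void validate(String email, String password) {
    if (email == null || email.isBlank()) {
      throw new ValidationException("Email is required");
    }
    // Deliberately loose format check; Cognito applies its own attribute rules
    if (!email.contains("@")) {
      throw new ValidationException("Email is not valid: " + email);
    }
    // Only presence is checked; the user pool's password policy is enforced by Cognito
    if (password == null || password.isEmpty()) {
      throw new ValidationException("Password is required");
    }
  }
}
```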

  19. Create Service:
  20. Create an interface named UserService that defines the contract for user operations:

    package com.example.app.user.service;
    
    import com.example.app.user.dto.SignInRequestDto;
    import com.example.app.user.dto.SignInResponseDto;
    import com.example.app.user.dto.SignUpRequestDto;
    import com.example.app.user.dto.SignUpResponseDto;
    
    public interface UserService {
    
      SignUpResponseDto signUp(SignUpRequestDto signUpRequest);
    
      SignInResponseDto signIn(SignInRequestDto signInRequest);
    
    }

  21. Create Service Implementation:
  22. Create an implementation class that implements the UserService interface and handles the business logic. Let's create a class called UserServiceImpl:

    package com.example.app.user.service.impl;
    
    import java.util.HashMap;
    import java.util.Map;
    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.beans.factory.annotation.Value;
    import org.springframework.stereotype.Service;
    import com.amazonaws.services.cognitoidp.AWSCognitoIdentityProvider;
    import com.amazonaws.services.cognitoidp.model.AdminCreateUserRequest;
    import com.amazonaws.services.cognitoidp.model.AdminCreateUserResult;
    import com.amazonaws.services.cognitoidp.model.AdminInitiateAuthRequest;
    import com.amazonaws.services.cognitoidp.model.AdminInitiateAuthResult;
    import com.amazonaws.services.cognitoidp.model.AdminRespondToAuthChallengeRequest;
    import com.amazonaws.services.cognitoidp.model.AdminRespondToAuthChallengeResult;
    import com.amazonaws.services.cognitoidp.model.AdminSetUserPasswordRequest;
    import com.amazonaws.services.cognitoidp.model.AttributeType;
    import com.amazonaws.services.cognitoidp.model.AuthFlowType;
    import com.amazonaws.services.cognitoidp.model.AuthenticationResultType;
    import com.amazonaws.services.cognitoidp.model.ChallengeNameType;
    import com.amazonaws.services.cognitoidp.model.DeliveryMediumType;
    import com.amazonaws.services.cognitoidp.model.MessageActionType;
    import com.example.app.exception.ValidationException;
    import com.example.app.user.dto.SignInRequestDto;
    import com.example.app.user.dto.SignInResponseDto;
    import com.example.app.user.dto.SignUpRequestDto;
    import com.example.app.user.dto.SignUpResponseDto;
    import com.example.app.user.service.UserService;
    
    @Service
    public class UserServiceImpl implements UserService {
    
      private static final Logger log = LoggerFactory.getLogger(UserServiceImpl.class);
    
      // Shared singleton bean: never call shutdown() on it here.
      // Spring shuts it down automatically when the application context closes.
      @Autowired
      private AWSCognitoIdentityProvider cognitoClient;
    
      @Value(value = "${aws.cognito.userPoolId}")
      private String userPoolId;
    
      @Value(value = "${aws.cognito.clientId}")
      private String clientId;
    
      @Override
      public SignUpResponseDto signUp(SignUpRequestDto signUpRequest) {
    
        SignUpResponseDto signUpResponse = new SignUpResponseDto();
        try {
    
          AttributeType emailAttr =
              new AttributeType().withName("email").withValue(signUpRequest.getEmail());
          AttributeType emailVerifiedAttr =
              new AttributeType().withName("email_verified").withValue("true");
    
          // Create the user without sending Cognito's invitation message (SUPPRESS)
          AdminCreateUserRequest userRequest = new AdminCreateUserRequest().withUserPoolId(userPoolId)
              .withUsername(signUpRequest.getEmail()).withTemporaryPassword(signUpRequest.getPassword())
              .withUserAttributes(emailAttr, emailVerifiedAttr)
              .withMessageAction(MessageActionType.SUPPRESS)
              .withDesiredDeliveryMediums(DeliveryMediumType.EMAIL);
    
          AdminCreateUserResult createUserResult = cognitoClient.adminCreateUser(userRequest);
    
          log.info("User {} is created. Status: {}", createUserResult.getUser().getUsername(),
              createUserResult.getUser().getUserStatus());
    
          // Make the password permanent so the user is not forced to change it at first login
          AdminSetUserPasswordRequest adminSetUserPasswordRequest = new AdminSetUserPasswordRequest()
              .withUsername(signUpRequest.getEmail()).withUserPoolId(userPoolId)
              .withPassword(signUpRequest.getPassword()).withPermanent(true);
    
          cognitoClient.adminSetUserPassword(adminSetUserPasswordRequest);
          signUpResponse.setStatusCode(0);
          signUpResponse.setStatusMessage("Successfully created user account.");
        } catch (Exception e) {
          throw new ValidationException("Error during sign up : " + e.getMessage());
        }
        return signUpResponse;
      }
    
      @Override
      public SignInResponseDto signIn(SignInRequestDto signInRequest) {
        final Map<String, String> authParams = new HashMap<>();
        authParams.put("USERNAME", signInRequest.getEmail());
        authParams.put("PASSWORD", signInRequest.getPassword());
    
        // ADMIN_NO_SRP_AUTH requires ALLOW_ADMIN_USER_PASSWORD_AUTH on the app client
        final AdminInitiateAuthRequest authRequest = new AdminInitiateAuthRequest()
            .withAuthFlow(AuthFlowType.ADMIN_NO_SRP_AUTH).withClientId(clientId)
            .withUserPoolId(userPoolId).withAuthParameters(authParams);
    
        try {
          AdminInitiateAuthResult result = cognitoClient.adminInitiateAuth(authRequest);
          String challengeName = result.getChallengeName();
    
          if (challengeName == null || challengeName.isEmpty()) {
            log.info("User has no challenge");
            return toSignInResponse(result.getAuthenticationResult());
          }
    
          log.info("Challenge name is {}", challengeName);
          if (!"NEW_PASSWORD_REQUIRED".equals(challengeName)) {
            throw new ValidationException("User has other challenge " + challengeName);
          }
          // The challenge can only be answered when the client supplies a new password
          if (signInRequest.getNewPassword() == null) {
            throw new ValidationException("User must change password " + challengeName);
          }
    
          final Map<String, String> challengeResponses = new HashMap<>();
          challengeResponses.put("USERNAME", signInRequest.getEmail());
          challengeResponses.put("PASSWORD", signInRequest.getPassword());
          challengeResponses.put("NEW_PASSWORD", signInRequest.getNewPassword());
    
          final AdminRespondToAuthChallengeRequest request = new AdminRespondToAuthChallengeRequest()
              .withChallengeName(ChallengeNameType.NEW_PASSWORD_REQUIRED)
              .withChallengeResponses(challengeResponses).withClientId(clientId)
              .withUserPoolId(userPoolId).withSession(result.getSession());
    
          AdminRespondToAuthChallengeResult resultChallenge =
              cognitoClient.adminRespondToAuthChallenge(request);
          return toSignInResponse(resultChallenge.getAuthenticationResult());
    
        } catch (ValidationException e) {
          throw e;
        } catch (Exception e) {
          throw new ValidationException(e.getMessage());
        }
      }
    
      // Copy the tokens issued by Cognito into the response DTO
      private SignInResponseDto toSignInResponse(AuthenticationResultType authenticationResult) {
        SignInResponseDto signInResponse = new SignInResponseDto();
        signInResponse.setAccessToken(authenticationResult.getAccessToken());
        signInResponse.setIdToken(authenticationResult.getIdToken());
        signInResponse.setRefreshToken(authenticationResult.getRefreshToken());
        signInResponse.setExpiresIn(authenticationResult.getExpiresIn());
        signInResponse.setTokenType(authenticationResult.getTokenType());
        return signInResponse;
      }
    
    }
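    The service sends only USERNAME and PASSWORD, which works only if the app client has no client secret. If your app client does have a secret, Cognito also expects a SECRET_HASH entry in the auth parameters and in the challenge responses. Its value is Base64(HMAC-SHA256(key = client secret, message = username + client ID)). The following is a minimal standard-library sketch; the class name is hypothetical, and the client secret would come from a new property that you add yourself.

```java
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

/** Hypothetical helper computing the SECRET_HASH value for an app client with a secret. */
final class CognitoSecretHash {

  private static final String HMAC_ALGORITHM = "HmacSHA256";

  private CognitoSecretHash() {}

  /** Base64(HMAC-SHA256(key = clientSecret, message = username + clientId)). */
  static String compute(String username, String clientId, String clientSecret) {
    try {
      Mac mac = Mac.getInstance(HMAC_ALGORITHM);
      mac.init(new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8), HMAC_ALGORITHM));
      byte[] digest = mac.doFinal((username + clientId).getBytes(StandardCharsets.UTF_8));
      return Base64.getEncoder().encodeToString(digest);
    } catch (NoSuchAlgorithmException | InvalidKeyException e) {
      // Every Java SE platform must support HmacSHA256, so this signals a broken environment
      throw new IllegalStateException("Unable to compute SECRET_HASH", e);
    }
  }
}
```

    In signIn, you would then add authParams.put("SECRET_HASH", CognitoSecretHash.compute(email, clientId, clientSecret)) and put the same entry in challengeResponses.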

  23. Create a Web Controller:
  24. Create a controller class named UserController. It handles HTTP requests and delegates to the UserService:

    package com.example.app.user.controller;
    
    import java.util.Arrays;
    import org.springframework.beans.factory.annotation.Autowired;
    import org.springframework.http.MediaType;
    import org.springframework.http.ResponseEntity;
    import org.springframework.web.bind.annotation.GetMapping;
    import org.springframework.web.bind.annotation.PostMapping;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RestController;
    import com.example.app.user.dto.SignInRequestDto;
    import com.example.app.user.dto.SignInResponseDto;
    import com.example.app.user.dto.SignUpRequestDto;
    import com.example.app.user.dto.SignUpResponseDto;
    import com.example.app.user.service.UserService;
    
    @RestController
    @RequestMapping(path = "/users")
    public class UserController {
      @Autowired
      private UserService userService;
    
      @PostMapping(path = "/signup", consumes = {MediaType.APPLICATION_JSON_VALUE},
          produces = {MediaType.APPLICATION_JSON_VALUE})
      public ResponseEntity<SignUpResponseDto> signUp(@RequestBody SignUpRequestDto signUpRequest) {
        return ResponseEntity.ok(userService.signUp(signUpRequest));
      }
    
      @PostMapping(path = "/signin", consumes = {MediaType.APPLICATION_JSON_VALUE})
      public ResponseEntity<SignInResponseDto> signIn(@RequestBody SignInRequestDto loginDto) {
        return ResponseEntity.ok(userService.signIn(loginDto));
      }
    
      @GetMapping(path = "/data")
      public ResponseEntity<?> data() {
        return ResponseEntity.ok(Arrays.asList("Hello world!"));
      }
    }

    The UserController class exposes three endpoints under /users: POST /signup and POST /signin, which are public, and GET /data, which requires a valid access token. It delegates user operations to UserService and returns JSON responses.


  25. Run and Test your Application:
  26. Use your IDE's build tools (Maven or Gradle) to build your project and resolve dependencies. Once the build is successful, run the main class of your application. You should see logs indicating that the application has started.

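    Test your endpoints with an API testing tool such as Postman:

    • User sign-up: send POST http://localhost:5000/users/signup with the header Content-Type: application/json and a body such as {"email": "jane@example.com", "password": "<password>"}. The password must satisfy your user pool's password policy.

    • User sign-in: send POST http://localhost:5000/users/signin with the same kind of JSON body. A successful response contains accessToken, idToken, refreshToken, expiresIn, and tokenType.

    • Secured API: call GET http://localhost:5000/users/data with the header Authorization: Bearer <accessToken>, using the access token received in the response after a successful sign-in. Without a valid token, the API responds with 401 Unauthorized.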
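    If you prefer code to Postman, the JDK's built-in java.net.http client (Java 11 and later) can send the same requests. The following is a hypothetical sketch that only builds the requests. It assembles the JSON body by string concatenation, which assumes that the email and password contain no characters that need JSON escaping.

```java
import java.net.URI;
import java.net.http.HttpRequest;

/** Hypothetical helpers that build the tutorial's test requests with java.net.http. */
final class ApiRequests {

  private ApiRequests() {}

  /** POST /users/signin; assumes email and password need no JSON escaping. */
  static HttpRequest signIn(String baseUrl, String email, String password) {
    String json = "{\"email\":\"" + email + "\",\"password\":\"" + password + "\"}";
    return HttpRequest.newBuilder(URI.create(baseUrl + "/users/signin"))
        .header("Content-Type", "application/json")
        .POST(HttpRequest.BodyPublishers.ofString(json))
        .build();
  }

  /** GET /users/data, authenticated with the access token from the sign-in response. */
  static HttpRequest data(String baseUrl, String accessToken) {
    return HttpRequest.newBuilder(URI.create(baseUrl + "/users/data"))
        .header("Authorization", "Bearer " + accessToken)
        .GET()
        .build();
  }
}
```

    To send a request, call HttpClient.newHttpClient().send(request, HttpResponse.BodyHandlers.ofString()) and read the status code and body from the returned HttpResponse. Note that send throws IOException and InterruptedException.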