[ASTERIXDB-3514][EXT]: Add support to cross-account trust authentication

- user model changes: no
- storage format changes: no
- interface changes: no

Details:
AWS supports granting (trusting) permissions to services in
another account to access its resources without the need to
pass any permanent credentials.

Ext-ref: MB-63505
Change-Id: I30933ac3fef0ae2fb09a88a02bd89fd5087b7071
Reviewed-on: https://asterix-gerrit.ics.uci.edu/c/asterixdb/+/18946
Integration-Tests: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Tested-by: Jenkins <jenkins@fulliautomatix.ics.uci.edu>
Reviewed-by: Hussain Towaileb <hussainht@gmail.com>
Reviewed-by: Michael Blow <mblow@apache.org>
diff --git a/asterixdb/asterix-external-data/pom.xml b/asterixdb/asterix-external-data/pom.xml
index 21eaf71..8c8ad10 100644
--- a/asterixdb/asterix-external-data/pom.xml
+++ b/asterixdb/asterix-external-data/pom.xml
@@ -449,6 +449,10 @@
     </dependency>
     <dependency>
       <groupId>software.amazon.awssdk</groupId>
+      <artifactId>sts</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>software.amazon.awssdk</groupId>
       <artifactId>s3</artifactId>
     </dependency>
     <dependency>
diff --git a/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Constants.java b/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Constants.java
index a62b346..126c868 100644
--- a/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Constants.java
+++ b/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Constants.java
@@ -28,6 +28,8 @@
     public static final String ACCESS_KEY_ID_FIELD_NAME = "accessKeyId";
     public static final String SECRET_ACCESS_KEY_FIELD_NAME = "secretAccessKey";
     public static final String SESSION_TOKEN_FIELD_NAME = "sessionToken";
+    public static final String ROLE_ARN_FIELD_NAME = "roleArn";
+    public static final String EXTERNAL_ID_FIELD_NAME = "externalId";
     public static final String SERVICE_END_POINT_FIELD_NAME = "serviceEndpoint";
 
     // AWS S3 specific error codes
diff --git a/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Utils.java b/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Utils.java
index 891d7f3..3cfccb4 100644
--- a/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Utils.java
+++ b/asterixdb/asterix-external-data/src/main/java/org/apache/asterix/external/util/aws/s3/S3Utils.java
@@ -30,6 +30,7 @@
 import static org.apache.asterix.external.util.aws.s3.S3Constants.ERROR_INTERNAL_ERROR;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.ERROR_METHOD_NOT_IMPLEMENTED;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.ERROR_SLOW_DOWN;
+import static org.apache.asterix.external.util.aws.s3.S3Constants.EXTERNAL_ID_FIELD_NAME;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.HADOOP_ACCESS_KEY_ID;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.HADOOP_ANONYMOUS_ACCESS;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.HADOOP_CREDENTIAL_PROVIDER_KEY;
@@ -42,6 +43,7 @@
 import static org.apache.asterix.external.util.aws.s3.S3Constants.HADOOP_TEMP_ACCESS;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.INSTANCE_PROFILE_FIELD_NAME;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.REGION_FIELD_NAME;
+import static org.apache.asterix.external.util.aws.s3.S3Constants.ROLE_ARN_FIELD_NAME;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.SECRET_ACCESS_KEY_FIELD_NAME;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.SERVICE_END_POINT_FIELD_NAME;
 import static org.apache.asterix.external.util.aws.s3.S3Constants.SESSION_TOKEN_FIELD_NAME;
@@ -54,6 +56,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.UUID;
 import java.util.function.BiPredicate;
 import java.util.regex.Matcher;
 
@@ -93,6 +96,10 @@
 import software.amazon.awssdk.services.s3.model.S3Object;
 import software.amazon.awssdk.services.s3.model.S3Response;
 import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
+import software.amazon.awssdk.services.sts.StsClient;
+import software.amazon.awssdk.services.sts.model.AssumeRoleRequest;
+import software.amazon.awssdk.services.sts.model.AssumeRoleResponse;
+import software.amazon.awssdk.services.sts.model.Credentials;
 
 public class S3Utils {
     private S3Utils() {
@@ -111,31 +118,16 @@
      * @throws CompilationException CompilationException
      */
     public static S3Client buildAwsS3Client(Map<String, String> configuration) throws CompilationException {
-        // TODO(Hussain): Need to ensure that all required parameters are present in a previous step
-        String instanceProfile = configuration.get(INSTANCE_PROFILE_FIELD_NAME);
-        String accessKeyId = configuration.get(ACCESS_KEY_ID_FIELD_NAME);
-        String secretAccessKey = configuration.get(SECRET_ACCESS_KEY_FIELD_NAME);
-        String sessionToken = configuration.get(SESSION_TOKEN_FIELD_NAME);
         String regionId = configuration.get(REGION_FIELD_NAME);
         String serviceEndpoint = configuration.get(SERVICE_END_POINT_FIELD_NAME);
 
+        Region region = validateAndGetRegion(regionId);
+        AwsCredentialsProvider credentialsProvider = buildCredentialsProvider(configuration);
+
         S3ClientBuilder builder = S3Client.builder();
-
-        // Credentials
-        AwsCredentialsProvider credentialsProvider =
-                buildCredentialsProvider(instanceProfile, accessKeyId, secretAccessKey, sessionToken);
-
+        builder.region(region);
         builder.credentialsProvider(credentialsProvider);
 
-        // Validate the region
-        List<Region> regions = S3Client.serviceMetadata().regions();
-        Optional<Region> selectedRegion = regions.stream().filter(region -> region.id().equals(regionId)).findFirst();
-
-        if (selectedRegion.isEmpty()) {
-            throw new CompilationException(S3_REGION_NOT_SUPPORTED, regionId);
-        }
-        builder.region(selectedRegion.get());
-
         // Validate the service endpoint if present
         if (serviceEndpoint != null) {
             try {
@@ -154,61 +146,32 @@
         return builder.build();
     }
 
-    public static AwsCredentialsProvider buildCredentialsProvider(String instanceProfile, String accessKeyId,
-            String secretAccessKey, String sessionToken) throws CompilationException {
+    public static AwsCredentialsProvider buildCredentialsProvider(Map<String, String> configuration)
+            throws CompilationException {
+        String arnRole = configuration.get(ROLE_ARN_FIELD_NAME);
+        String externalId = configuration.get(EXTERNAL_ID_FIELD_NAME);
+        String instanceProfile = configuration.get(INSTANCE_PROFILE_FIELD_NAME);
+        String accessKeyId = configuration.get(ACCESS_KEY_ID_FIELD_NAME);
+        String secretAccessKey = configuration.get(SECRET_ACCESS_KEY_FIELD_NAME);
 
-        // Credentials
-        AwsCredentialsProvider credentialsProvider;
-
-        // nothing provided, anonymous authentication
-        if (instanceProfile == null && accessKeyId == null && secretAccessKey == null && sessionToken == null) {
-            credentialsProvider = AnonymousCredentialsProvider.create();
+        if (noAuth(configuration)) {
+            return AnonymousCredentialsProvider.create();
+        } else if (arnRole != null) {
+            // TODO: Do auth validation and use existing credentials if exist already, if not, assume the role
+            return validateAndGetTrustAccountAuthentication(configuration);
         } else if (instanceProfile != null) {
-
-            // only "true" value is allowed
-            if (!instanceProfile.equalsIgnoreCase("true")) {
-                throw new CompilationException(INVALID_PARAM_VALUE_ALLOWED_VALUE, INSTANCE_PROFILE_FIELD_NAME, "true");
-            }
-
-            // no other authentication parameters are allowed
-            if (accessKeyId != null) {
-                throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, ACCESS_KEY_ID_FIELD_NAME,
-                        INSTANCE_PROFILE_FIELD_NAME);
-            }
-            if (secretAccessKey != null) {
-                throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, SECRET_ACCESS_KEY_FIELD_NAME,
-                        INSTANCE_PROFILE_FIELD_NAME);
-            }
-            if (sessionToken != null) {
-                throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, SESSION_TOKEN_FIELD_NAME,
-                        INSTANCE_PROFILE_FIELD_NAME);
-            }
-            credentialsProvider = InstanceProfileCredentialsProvider.create();
+            return validateAndGetInstanceProfileAuthentication(configuration);
         } else if (accessKeyId != null || secretAccessKey != null) {
-            // accessKeyId authentication
-            if (accessKeyId == null) {
-                throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ACCESS_KEY_ID_FIELD_NAME,
-                        SECRET_ACCESS_KEY_FIELD_NAME);
-            }
-            if (secretAccessKey == null) {
-                throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, SECRET_ACCESS_KEY_FIELD_NAME,
-                        ACCESS_KEY_ID_FIELD_NAME);
-            }
-
-            // use session token if provided
-            if (sessionToken != null) {
-                credentialsProvider = StaticCredentialsProvider
-                        .create(AwsSessionCredentials.create(accessKeyId, secretAccessKey, sessionToken));
-            } else {
-                credentialsProvider =
-                        StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKeyId, secretAccessKey));
-            }
+            return validateAndGetAccessKeysAuthentications(configuration);
         } else {
-            // if only session token is provided, accessKeyId is required
-            throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ACCESS_KEY_ID_FIELD_NAME,
-                    SESSION_TOKEN_FIELD_NAME);
+            if (externalId != null) {
+                throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ROLE_ARN_FIELD_NAME,
+                        EXTERNAL_ID_FIELD_NAME);
+            } else {
+                throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ACCESS_KEY_ID_FIELD_NAME,
+                        SESSION_TOKEN_FIELD_NAME);
+            }
         }
-        return credentialsProvider;
     }
 
     /**
@@ -282,10 +245,22 @@
             throw new CompilationException(ErrorCode.PARAMETERS_REQUIRED, srcLoc, ExternalDataConstants.KEY_FORMAT);
         }
 
-        // Both parameters should be passed, or neither should be passed (for anonymous/no auth)
+        String arnRole = configuration.get(ROLE_ARN_FIELD_NAME);
+        String externalId = configuration.get(EXTERNAL_ID_FIELD_NAME);
         String accessKeyId = configuration.get(ACCESS_KEY_ID_FIELD_NAME);
         String secretAccessKey = configuration.get(SECRET_ACCESS_KEY_FIELD_NAME);
-        if (accessKeyId == null || secretAccessKey == null) {
+
+        if (arnRole != null) {
+            String notAllowed = getNonNull(configuration, ACCESS_KEY_ID_FIELD_NAME, SECRET_ACCESS_KEY_FIELD_NAME,
+                    SESSION_TOKEN_FIELD_NAME);
+            if (notAllowed != null) {
+                throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, notAllowed,
+                        INSTANCE_PROFILE_FIELD_NAME);
+            }
+        } else if (externalId != null) {
+            throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ROLE_ARN_FIELD_NAME,
+                    EXTERNAL_ID_FIELD_NAME);
+        } else if (accessKeyId == null || secretAccessKey == null) {
             // If one is passed, the other is required
             if (accessKeyId != null) {
                 throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, SECRET_ACCESS_KEY_FIELD_NAME,
@@ -528,7 +503,7 @@
     }
 
     public static Map<String, List<String>> S3ObjectsOfSingleDepth(Map<String, String> configuration, String container,
-            String prefix) throws CompilationException, HyracksDataException {
+            String prefix) throws CompilationException {
         // create s3 client
         S3Client s3Client = buildAwsS3Client(configuration);
         // fetch all the s3 objects
@@ -543,7 +518,7 @@
      * @param prefix                definition prefix
      */
     private static Map<String, List<String>> listS3ObjectsOfSingleDepth(S3Client s3Client, String container,
-            String prefix) throws HyracksDataException {
+            String prefix) {
         Map<String, List<String>> allObjects = new HashMap<>();
         ListObjectsV2Iterable listObjectsInterable;
         ListObjectsV2Request.Builder listObjectsBuilder =
@@ -580,4 +555,116 @@
         allObjects.put("folders", folders);
         return allObjects;
     }
+
+    public static Region validateAndGetRegion(String regionId) throws CompilationException {
+        List<Region> regions = S3Client.serviceMetadata().regions();
+        Optional<Region> selectedRegion = regions.stream().filter(region -> region.id().equals(regionId)).findFirst();
+
+        if (selectedRegion.isEmpty()) {
+            throw new CompilationException(S3_REGION_NOT_SUPPORTED, regionId);
+        }
+        return selectedRegion.get();
+    }
+
+    // TODO(htowaileb): Currently, trust-account is always assuming we have instance profile setup in place
+    private static AwsCredentialsProvider validateAndGetTrustAccountAuthentication(Map<String, String> configuration)
+            throws CompilationException {
+        String notAllowed = getNonNull(configuration, ACCESS_KEY_ID_FIELD_NAME, SECRET_ACCESS_KEY_FIELD_NAME,
+                SESSION_TOKEN_FIELD_NAME);
+        if (notAllowed != null) {
+            throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, notAllowed,
+                    INSTANCE_PROFILE_FIELD_NAME);
+        }
+
+        String regionId = configuration.get(REGION_FIELD_NAME);
+        String arnRole = configuration.get(ROLE_ARN_FIELD_NAME);
+        String externalId = configuration.get(EXTERNAL_ID_FIELD_NAME);
+        Region region = validateAndGetRegion(regionId);
+
+        AssumeRoleRequest.Builder builder = AssumeRoleRequest.builder();
+        builder.roleArn(arnRole);
+        builder.roleSessionName(UUID.randomUUID().toString());
+        builder.durationSeconds(900); // minimum role assume duration = 900 seconds (15 minutes), make configurable?
+        if (externalId != null) {
+            builder.externalId(externalId);
+        }
+        AssumeRoleRequest request = builder.build();
+        AwsCredentialsProvider credentialsProvider = validateAndGetInstanceProfileAuthentication(configuration);
+
+        // TODO(htowaileb): We shouldn't assume role with each request, rather stored the received temporary credentials
+        // and refresh when expired.
+        // assume the role from the provided arn
+        try (StsClient stsClient =
+                StsClient.builder().region(region).credentialsProvider(credentialsProvider).build()) {
+            AssumeRoleResponse response = stsClient.assumeRole(request);
+            Credentials credentials = response.credentials();
+            return StaticCredentialsProvider.create(AwsSessionCredentials.create(credentials.accessKeyId(),
+                    credentials.secretAccessKey(), credentials.sessionToken()));
+        } catch (SdkException ex) {
+            throw new CompilationException(ErrorCode.EXTERNAL_SOURCE_ERROR, ex, getMessageOrToString(ex));
+        }
+    }
+
+    private static AwsCredentialsProvider validateAndGetInstanceProfileAuthentication(Map<String, String> configuration)
+            throws CompilationException {
+        String instanceProfile = configuration.get(INSTANCE_PROFILE_FIELD_NAME);
+
+        // only "true" value is allowed
+        if (!"true".equalsIgnoreCase(instanceProfile)) {
+            throw new CompilationException(INVALID_PARAM_VALUE_ALLOWED_VALUE, INSTANCE_PROFILE_FIELD_NAME, "true");
+        }
+
+        String notAllowed = getNonNull(configuration, ACCESS_KEY_ID_FIELD_NAME, SECRET_ACCESS_KEY_FIELD_NAME,
+                SESSION_TOKEN_FIELD_NAME);
+        if (notAllowed != null) {
+            throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, notAllowed,
+                    INSTANCE_PROFILE_FIELD_NAME);
+        }
+        return InstanceProfileCredentialsProvider.create();
+    }
+
+    private static AwsCredentialsProvider validateAndGetAccessKeysAuthentications(Map<String, String> configuration)
+            throws CompilationException {
+        String accessKeyId = configuration.get(ACCESS_KEY_ID_FIELD_NAME);
+        String secretAccessKey = configuration.get(SECRET_ACCESS_KEY_FIELD_NAME);
+        String sessionToken = configuration.get(SESSION_TOKEN_FIELD_NAME);
+
+        // accessKeyId authentication
+        if (accessKeyId == null) {
+            throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, ACCESS_KEY_ID_FIELD_NAME,
+                    SECRET_ACCESS_KEY_FIELD_NAME);
+        }
+        if (secretAccessKey == null) {
+            throw new CompilationException(REQUIRED_PARAM_IF_PARAM_IS_PRESENT, SECRET_ACCESS_KEY_FIELD_NAME,
+                    ACCESS_KEY_ID_FIELD_NAME);
+        }
+
+        String notAllowed = getNonNull(configuration, EXTERNAL_ID_FIELD_NAME);
+        if (notAllowed != null) {
+            throw new CompilationException(PARAM_NOT_ALLOWED_IF_PARAM_IS_PRESENT, notAllowed,
+                    INSTANCE_PROFILE_FIELD_NAME);
+        }
+
+        // use session token if provided
+        if (sessionToken != null) {
+            return StaticCredentialsProvider
+                    .create(AwsSessionCredentials.create(accessKeyId, secretAccessKey, sessionToken));
+        } else {
+            return StaticCredentialsProvider.create(AwsBasicCredentials.create(accessKeyId, secretAccessKey));
+        }
+    }
+
+    private static boolean noAuth(Map<String, String> configuration) {
+        return getNonNull(configuration, INSTANCE_PROFILE_FIELD_NAME, ROLE_ARN_FIELD_NAME, EXTERNAL_ID_FIELD_NAME,
+                ACCESS_KEY_ID_FIELD_NAME, SECRET_ACCESS_KEY_FIELD_NAME, SESSION_TOKEN_FIELD_NAME) == null;
+    }
+
+    private static String getNonNull(Map<String, String> configuration, String... fieldNames) {
+        for (String fieldName : fieldNames) {
+            if (configuration.get(fieldName) != null) {
+                return fieldName;
+            }
+        }
+        return null;
+    }
 }