// Workload Identity Federation: GCS access using AWS credentials exchanged
// through GCP Workload Identity Federation (external_account credentials).

import { Injectable, Logger } from '@nestjs/common';
import { Storage } from '@google-cloud/storage';
import { PassThrough, Readable } from 'stream';
import {
  AwsClient,
  AwsSecurityCredentials,
  AwsSecurityCredentialsSupplier,
} from 'google-auth-library';
import { fromNodeProviderChain } from '@aws-sdk/credential-providers';

// AWS security credentials augmented with the optional expiration timestamp
// reported by the AWS SDK, so the service can decide when a refresh is needed.
interface CachedCredentials extends AwsSecurityCredentials {
  expiration?: Date;
}
@Injectable()
export class GCSService {
  private gcsClient!: Storage;
  private cachedCredentials?: CachedCredentials;
  private readonly logger = new Logger(GCSService.name);

  constructor() {
    const decodedVertexAiAccount = this.decodeAccountKey();
    if (decodedVertexAiAccount.type === 'external_account') {
      // A constructor cannot await; handle the rejection explicitly instead of
      // leaving a floating promise. Every public method re-runs reinit anyway,
      // so a failure here is retried on first use.
      this.reinit_client_by_aws_external_account().catch((err) =>
        this.logger.error(err, 'Initial GCS client setup failed'),
      );
    } else {
      this.gcsClient = new Storage({ credentials: decodedVertexAiAccount });
    }
  }

  /**
   * Decodes the base64-encoded account JSON from VERTEX_AI_ACCOUNT_KEY.
   *
   * @returns The parsed credential object (service-account or external-account JSON).
   * @throws Error when the environment variable is unset; SyntaxError when it
   *   does not decode to valid JSON.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- raw JSON.parse result
  private decodeAccountKey(): Record<string, any> {
    const raw = process.env.VERTEX_AI_ACCOUNT_KEY;
    if (!raw) {
      throw new Error('VERTEX_AI_ACCOUNT_KEY environment variable is not set');
    }
    return JSON.parse(Buffer.from(raw, 'base64').toString('utf-8'));
  }

  /**
   * Returns true when the credentials are expired or expire within the next
   * 5 minutes (or when no expiration is known, forcing a refresh).
   */
  private isTokenExpired(credentials: CachedCredentials): boolean {
    if (!credentials.expiration) {
      // If no expiration, assume token might be expired and refresh.
      return true;
    }
    const bufferTime = 5 * 60 * 1000; // 5 minutes in milliseconds
    return credentials.expiration.getTime() - Date.now() <= bufferTime;
  }

  /**
   * (Re)creates the GCS client using AWS credentials exchanged via Workload
   * Identity Federation. No-op for non-external-account keys; skips the
   * rebuild while the cached AWS credentials are still valid.
   */
  async reinit_client_by_aws_external_account(): Promise<void> {
    const decodedVertexAiAccount = this.decodeAccountKey();
    if (decodedVertexAiAccount.type !== 'external_account') {
      return;
    }
    if (this.cachedCredentials && !this.isTokenExpired(this.cachedCredentials)) {
      this.logger.log('Using cached AWS credentials');
      return;
    }
    this.logger.log('Reinitializing GCS client by AWS external account');
    // SECURITY: never log the credential object itself — it contains the AWS
    // secret access key. Log only the (non-sensitive) expiration timestamp.
    this.logger.log(
      `Cached AWS credentials expiration: ${
        this.cachedCredentials?.expiration?.toISOString() ?? 'none'
      }`,
    );

    // Supplies AWS credentials from the default Node provider chain
    // (env vars, shared config, IMDS, …) to google-auth-library's AwsClient.
    class AwsSupplier implements AwsSecurityCredentialsSupplier {
      private readonly region = process.env.AWS_S3_REGION ?? 'ap-east-1';

      async getAwsRegion(): Promise<string> {
        return this.region;
      }

      async getAwsSecurityCredentials(): Promise<CachedCredentials> {
        try {
          const awsCredentialsProvider = fromNodeProviderChain();
          const awsCredentials = await awsCredentialsProvider();
          return {
            accessKeyId: awsCredentials.accessKeyId,
            secretAccessKey: awsCredentials.secretAccessKey,
            token: awsCredentials.sessionToken,
            expiration: awsCredentials.expiration,
          };
        } catch (error: unknown) {
          // Narrow the unknown catch variable before touching .message.
          const message = error instanceof Error ? error.message : String(error);
          throw new Error(`AWS credentials refresh failed: ${message}`);
        }
      }
    }

    const awsSupplier = new AwsSupplier();
    // Prime the cache so the next call can short-circuit while still valid.
    this.cachedCredentials = await awsSupplier.getAwsSecurityCredentials();

    this.gcsClient = new Storage({
      authClient: new AwsClient({
        // Option keys follow the external-account credential config JSON
        // (snake_case). The previous camelCase `subjectTokenType` key was not
        // part of AwsClientOptions and was silently ignored.
        subject_token_type: decodedVertexAiAccount.subject_token_type,
        audience: decodedVertexAiAccount.audience,
        service_account_impersonation_url:
          decodedVertexAiAccount.service_account_impersonation_url,
        aws_security_credentials_supplier: awsSupplier,
      }),
    });
  }

  /**
   * Lists all files in the bucket named by GCS_INSIGHT_FILE_BUCKET.
   * @returns The files in the bucket.
   */
  async listFiles() {
    await this.reinit_client_by_aws_external_account();
    const [files] = await this.gcsClient
      .bucket(process.env.GCS_INSIGHT_FILE_BUCKET ?? '')
      .getFiles();
    return files;
  }

  /**
   * Streams `stream` into `bucketName/fileName`.
   *
   * @param bucketName - Destination GCS bucket.
   * @param fileName - Destination object name; also the resolved value.
   * @param stream - Source readable stream.
   * @returns The uploaded file name once the write stream finishes.
   * @throws Propagates errors from either the source stream or the GCS write stream.
   */
  async uploadFile(bucketName: string, fileName: string, stream: PassThrough | Readable) {
    await this.reinit_client_by_aws_external_account();
    const file = this.gcsClient.bucket(bucketName).file(fileName);
    return new Promise<string>((resolve, reject) => {
      stream
        .pipe(file.createWriteStream())
        .on('error', (err) => {
          this.logger.error(err, `GCS - ${fileName} upload failed`);
          reject(err);
        })
        .on('finish', () => {
          this.logger.log(`GCS - ${fileName} uploaded to ${bucketName}`);
          resolve(fileName);
        });

      // Errors on the source stream do not propagate through pipe();
      // reject on them as well so the promise cannot hang.
      stream.on('error', (err) => {
        this.logger.error(err, `Stream error while uploading ${fileName}`);
        reject(err);
      });
    });
  }
}
