// Workload Identity Federation
import { Injectable, Logger } from '@nestjs/common';
import { Storage } from '@google-cloud/storage';
import { PassThrough, Readable } from 'stream';
import {
AwsClient,
AwsSecurityCredentials,
AwsSecurityCredentialsSupplier,
} from 'google-auth-library';
import { fromNodeProviderChain } from '@aws-sdk/credential-providers';
/**
 * AWS security credentials extended with the optional expiry reported by the
 * AWS SDK, so a cached copy can be checked for staleness.
 */
interface CachedCredentials extends AwsSecurityCredentials {
// Expiry of the credentials; when absent, isTokenExpired treats them as expired.
expiration?: Date;
}
/**
 * NestJS service wrapping a Google Cloud Storage client.
 *
 * The client is configured from the base64-encoded JSON held in the
 * `VERTEX_AI_ACCOUNT_KEY` environment variable:
 * - `type: "external_account"`: the client authenticates through an
 *   `AwsClient` whose AWS credentials come from the AWS SDK default provider
 *   chain;
 * - any other type: the decoded JSON is passed to `Storage` as `credentials`.
 */
@Injectable()
export class GCSService {
  /** Cached credentials expiring within this window are treated as expired. */
  private static readonly EXPIRY_BUFFER_MS = 5 * 60 * 1000;

  private gcsClient: Storage;
  private cachedCredentials?: CachedCredentials;
  private readonly logger = new Logger(GCSService.name);

  /**
   * Builds the storage client.
   *
   * @throws Error when `VERTEX_AI_ACCOUNT_KEY` is missing or cannot be decoded.
   */
  constructor() {
    const account = this.readAccountConfig();
    if (account.type === 'external_account') {
      // A constructor cannot await. Catch here so a failed credential lookup at
      // startup is logged instead of surfacing as an unhandled rejection; every
      // public method re-runs the initialisation before touching the client.
      this.reinit_client_by_aws_external_account().catch((error: unknown) => {
        this.logger.error(error, 'Initial GCS client setup by AWS external account failed');
      });
    } else {
      this.gcsClient = new Storage({ credentials: account });
    }
  }

  /**
   * Decodes and parses the account configuration from `VERTEX_AI_ACCOUNT_KEY`.
   *
   * @returns The parsed JSON object (shape depends on the account type).
   * @throws Error when the variable is unset, is not base64-encoded JSON, or
   *   does not decode to an object.
   */
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- passed through to Storage/AwsClient as-is
  private readAccountConfig(): any {
    const encoded = process.env.VERTEX_AI_ACCOUNT_KEY;
    if (!encoded) {
      throw new Error('VERTEX_AI_ACCOUNT_KEY is not set; cannot configure the GCS client');
    }

    let parsed: unknown;
    try {
      parsed = JSON.parse(Buffer.from(encoded, 'base64').toString('utf-8'));
    } catch {
      // The underlying parse error is deliberately dropped: its message can
      // quote part of the payload, which may contain key material.
      throw new Error('VERTEX_AI_ACCOUNT_KEY is not valid base64-encoded JSON');
    }

    if (typeof parsed !== 'object' || parsed === null) {
      throw new Error('VERTEX_AI_ACCOUNT_KEY must decode to a JSON object');
    }
    return parsed;
  }

  /**
   * Reports whether cached credentials should be refreshed.
   *
   * @param credentials Cached credentials to inspect.
   * @returns `true` when there is no expiry or the expiry is at most five
   *   minutes away; `false` otherwise.
   */
  private isTokenExpired(credentials: CachedCredentials): boolean {
    if (!credentials.expiration) {
      // Without an expiry there is no way to prove freshness, so refresh.
      return true;
    }
    const remainingMs = credentials.expiration.getTime() - Date.now();
    return remainingMs <= GCSService.EXPIRY_BUFFER_MS;
  }

  /**
   * Rebuilds the storage client for an `external_account` configuration.
   *
   * Does nothing when the configured account is not an external account or
   * when the cached AWS credentials are still valid. Otherwise it fetches
   * fresh AWS credentials, caches them, and replaces the storage client with
   * one backed by a new `AwsClient`.
   *
   * @throws Error when the account configuration cannot be read or the AWS
   *   credentials cannot be obtained.
   */
  async reinit_client_by_aws_external_account(): Promise<void> {
    const account = this.readAccountConfig();
    if (account.type !== 'external_account') {
      return;
    }
    if (this.cachedCredentials && !this.isTokenExpired(this.cachedCredentials)) {
      this.logger.log('Using cached AWS credentials');
      return;
    }

    // Logged only when a rebuild really happens. The credentials themselves
    // are never logged: they contain the secret key and session token.
    this.logger.log('Reinitializing GCS client by AWS external account');

    const region = process.env.AWS_S3_REGION ?? 'ap-east-1';

    const fetchAwsCredentials = async (): Promise<CachedCredentials> => {
      try {
        const resolved = await fromNodeProviderChain()();
        return {
          accessKeyId: resolved.accessKeyId,
          secretAccessKey: resolved.secretAccessKey,
          token: resolved.sessionToken,
          expiration: resolved.expiration,
        };
      } catch (error: unknown) {
        const reason = error instanceof Error ? error.message : String(error);
        throw new Error(`AWS credentials refresh failed: ${reason}`);
      }
    };

    const awsSupplier: AwsSecurityCredentialsSupplier = {
      getAwsRegion: async () => region,
      getAwsSecurityCredentials: fetchAwsCredentials,
    };

    // Fetched once here purely to record the expiry used by isTokenExpired;
    // the AwsClient calls the supplier itself when it needs credentials.
    this.cachedCredentials = await fetchAwsCredentials();

    const clientOptions = {
      subjectTokenType: account.subject_token_type,
      audience: account.audience,
      service_account_impersonation_url: account.service_account_impersonation_url,
      aws_security_credentials_supplier: awsSupplier,
    };
    this.gcsClient = new Storage({
      authClient: new AwsClient(clientOptions),
    });
  }

  /**
   * Lists the files in the bucket named by `GCS_INSIGHT_FILE_BUCKET`.
   *
   * @returns The file objects returned by the bucket's `getFiles()` call.
   */
  async listFiles() {
    await this.reinit_client_by_aws_external_account();
    const bucketName = process.env.GCS_INSIGHT_FILE_BUCKET ?? '';
    const [files] = await this.gcsClient.bucket(bucketName).getFiles();
    return files;
  }

  /**
   * Streams data into an object in the given bucket.
   *
   * @param bucketName Destination bucket.
   * @param fileName Destination object name.
   * @param stream Source of the object's content.
   * @returns Resolves with `fileName` once the write stream has finished.
   * @throws Rejects with the underlying error when either the source stream
   *   or the GCS write stream fails.
   */
  async uploadFile(
    bucketName: string,
    fileName: string,
    stream: PassThrough | Readable,
  ): Promise<string> {
    await this.reinit_client_by_aws_external_account();
    const file = this.gcsClient.bucket(bucketName).file(fileName);

    return new Promise<string>((resolve, reject) => {
      const writeStream = file.createWriteStream();

      writeStream
        .on('error', (err: Error) => {
          this.logger.error(err, `GCS - ${fileName} upload failed`);
          reject(err);
        })
        .on('finish', () => {
          this.logger.log(`GCS - ${fileName} uploaded to ${bucketName}`);
          resolve(fileName);
        });

      stream.on('error', (err: Error) => {
        this.logger.error(err, `Stream error while uploading ${fileName}`);
        reject(err);
        // pipe() does not close the destination when the source fails, so
        // tear the upload down explicitly. No error is passed to destroy()
        // so the write stream's own error handler does not log it again.
        writeStream.destroy();
      });

      stream.pipe(writeStream);
    });
  }
}
Last updated
Was this helpful?