Source code for nmdc_client.auth

# -*- coding: utf-8 -*-

import logging
from datetime import datetime, timedelta
from typing import Any

import requests

from nmdc_client.api_client import NMDCAPIClient
from nmdc_client.config import API_BASE_URL
from nmdc_client.decorators import has_deprecated_parameter

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


@has_deprecated_parameter("env", reason="Use ``api_base_url`` instead.")
class NMDCAuth(NMDCAPIClient):
    """
    Authentication handler for NMDC API operations.

    Parameters
    ----------
    client_id
        The client ID for NMDC API authentication. See Notes for further details.
    client_secret
        The client secret for NMDC API authentication. See Notes for further details.
    username
        The username for NMDC API authentication. See Notes for further details.
    password
        The password for NMDC API authentication. See Notes for further details.
    api_base_url
        The base URL of an instance of the NMDC Runtime API. By default, this is
        the base URL of the production instance.
    env
        Deprecated. Use ``api_base_url`` instead. Previously used to specify the
        API environment (e.g., "prod", "dev").

    Attributes
    ----------
    grant_type
        ``"client_credentials"`` when both ``client_id`` and ``client_secret``
        are given, otherwise ``"password"`` when both ``username`` and
        ``password`` are given, otherwise ``None``. Computed once, at
        construction time.
    token_request_timeout
        Timeout, in seconds, for the HTTP request to the ``/token`` endpoint.
        Passed straight through to ``requests.post``; may be overridden per
        instance or per subclass (``None`` disables the timeout).

    Notes
    -----
    Security Warning - your credentials should be stored in a secure location.
    Do not hard-code these values in your code; we recommend using environment
    variables.

    You must provide either:

    - ``client_id`` and ``client_secret`` (for client credentials grant), OR
    - ``username`` and ``password`` (for password grant).

    If both pairs are provided, the client credentials grant is used.
    """

    # Without a timeout, ``requests`` waits forever on an unresponsive server.
    token_request_timeout: float | None = 30.0

    def __init__(
        self,
        client_id: str | None = None,
        client_secret: str | None = None,
        username: str | None = None,
        password: str | None = None,
        api_base_url: str = API_BASE_URL,
        env: str = "",
    ):
        super().__init__(
            api_base_url=api_base_url,
            env=env,
        )
        self.client_id = client_id
        self.client_secret = client_secret
        self.username = username
        self.password = password
        # Cached access token and the (naive, local wall-clock) moment after
        # which it is no longer reused; both are populated by _refresh_token().
        self._token: str | None = None
        self._token_expires_at: datetime | None = None
        self._oauth_session: Any | None = None
        # Client credentials take precedence when both credential pairs are set.
        self.grant_type: str | None = (
            "client_credentials"
            if (self.client_id and self.client_secret)
            else "password"
            if (self.username and self.password)
            else None
        )

    def has_credentials(self) -> bool:
        """
        Check whether a complete credential pair was provided.

        Returns
        -------
        bool
            ``True`` if both ``client_id`` and ``client_secret`` are truthy, or
            both ``username`` and ``password`` are truthy; ``False`` otherwise.
        """
        if self.client_id and self.client_secret:
            return True
        elif self.username and self.password:
            return True
        return False

    def get_token(self) -> str:
        """
        Get a valid access token, refreshing if necessary.

        Returns the cached token while it is unexpired; otherwise requests a new
        one via ``_refresh_token``.

        Returns
        -------
        str
            An access token.
        """
        if self._is_token_valid():
            assert isinstance(self._token, str)  # to appease mypy
            return self._token
        return self._refresh_token()

    def _is_token_valid(self) -> bool:
        """
        Check if the current token is present and not expired.

        A token with no recorded expiry is treated as invalid, so it is never
        reused. The comparison uses local wall-clock time (``datetime.now()``).
        """
        if not self._token or not self._token_expires_at:
            return False
        return datetime.now() < self._token_expires_at

    def _refresh_token(self) -> str:
        """
        Request a new access token from ``{api_base_url}/token`` and cache it.

        Returns
        -------
        str
            The newly issued access token.

        Raises
        ------
        ValueError
            If no complete credential pair was provided (``grant_type`` is None).
        Exception
            If the response body is not a JSON object containing ``access_token``
            (including non-JSON error pages).
        requests.exceptions.RequestException
            If the HTTP request itself fails, e.g. on connection error or when
            ``token_request_timeout`` elapses.
        """
        if self.grant_type == "client_credentials":
            token_request_body = {
                "grant_type": "client_credentials",
                "client_id": self.client_id,
                "client_secret": self.client_secret,
            }
        elif self.grant_type == "password":
            token_request_body = {
                "grant_type": "password",
                "username": self.username,
                "password": self.password,
            }
        else:
            raise ValueError(
                "Refreshing a token requires that credentials be specified."
            )

        response = requests.post(
            f"{self.api_base_url}/token",
            headers=self._build_http_request_headers(),
            data=token_request_body,
            timeout=self.token_request_timeout,
        )

        try:
            token_response = response.json()
        except ValueError:
            # Non-JSON body (e.g. an HTML error page from a proxy). Catching
            # ValueError covers json.JSONDecodeError and
            # requests.exceptions.JSONDecodeError, which subclasses it.
            token_response = None

        if not isinstance(token_response, dict) or "access_token" not in token_response:
            # Fall back to the raw body so the failure message stays informative.
            detail = token_response if token_response is not None else response.text
            logger.error(
                f"Token refresh failed: {detail}, Status: {response.status_code}"
            )
            raise Exception(f"Token refresh failed: {detail}")

        self._token = token_response["access_token"]

        # Expiry is read as a mapping with optional "days"/"hours"/"minutes" keys.
        expires_info = token_response.get("expires")
        if expires_info:
            expires_delta = timedelta(
                days=expires_info.get("days", 0),
                hours=expires_info.get("hours", 0),
                minutes=expires_info.get("minutes", 0),
            )
            # Subtract a 60 s buffer so the token is not used right at expiry.
            self._token_expires_at = (
                datetime.now() + expires_delta - timedelta(seconds=60)
            )
        else:
            # No lifetime reported for this token: discard any expiry left over
            # from a previous token rather than trusting an unrelated deadline.
            # With no expiry, _is_token_valid() reports False, so the next
            # get_token() call refreshes again.
            self._token_expires_at = None

        assert isinstance(self._token, str)  # to appease mypy
        return self._token