############################################################################

# Copyright 2022 Workiva Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

############################################################################

import decimal
import json
import os
import sys
from enum import Enum

import requests

# Workiva endpoints: the OAuth2 token endpoint and the base URL that the
# spreadsheet requests below are built on.
AUTH_URL = "https://api.app.wdesk.com/iam/v1/oauth2/token"
SS_API_URL = "https://api.app.wdesk.com/platform/v1/spreadsheets/"

# Capture parameters (inputs) when you run the Script
# from the Scripting Editor or Chain
CLIENT_ID = os.getenv("CLIENT_ID")
CLIENT_SECRET = os.getenv("CLIENT_SECRET")

DOCUMENT_ID = os.getenv("DOCUMENT_ID")
SHEET_ID = os.getenv("SHEET_ID")

# Override inputs when you run the Script from an integrated automation
# NOTE(review): this reads the same variable as DOCUMENT_ID above, so the
# branch below runs whenever a document id is supplied at all - confirm
# whether integrated automations are meant to use a different variable name.
IA_DOCUMENT_ID = os.getenv("DOCUMENT_ID")

# Hardcode these if triggered from Integrated Automation
IA_CLIENT_ID = ""
IA_CLIENT_SECRET = ""

if IA_DOCUMENT_ID:
    # A non-empty hardcoded credential wins; an empty placeholder must not
    # wipe out a credential that was supplied through the environment.
    CLIENT_ID = IA_CLIENT_ID or CLIENT_ID or ""
    CLIENT_SECRET = IA_CLIENT_SECRET or CLIENT_SECRET or ""

    DOCUMENT_ID = IA_DOCUMENT_ID


class NumberPrecision(Enum):
    """
    Scale factors for the "shown in" number format of a cell.

    Members are looked up by name (the format's ``shownIn`` string with
    spaces replaced by underscores) and the cell value is divided by the
    member's value to obtain the number as displayed.

    The two fractional members are floats and the rest are ints. Note that
    ``decimal.Decimal`` does not support arithmetic with float operands, so
    a float member has to be converted before it is combined with a Decimal.
    """

    # Fractional scales (float)
    BASIS_POINTS = 1e-4
    HUNDREDTHS = 1e-2

    # Whole-number scales (int), written as powers of ten
    ONES = 1
    THOUSANDS = 10**3
    TEN_THOUSANDS = 10**4
    MILLIONS = 10**6
    HUNDRED_MILLIONS = 10**8
    BILLIONS = 10**9
    TRILLIONS = 10**12


class ApiAuth:
    """
    Obtains an OAuth2 access token using the client-credentials grant.

    The credentials are read from the module-level CLIENT_ID and
    CLIENT_SECRET values and posted to AUTH_URL.
    """

    def __init__(self):
        self._headers = {
            "Content-Type": "application/x-www-form-urlencoded;charset=UTF-8"
        }

    def get_auth_token(self):
        """
        Requests an access token and returns it as a string.

        Prints the error body and exits the process with status 1 when the
        token endpoint answers with a non-OK status.
        """
        # Passing a mapping lets requests form-encode the values, so that
        # credentials containing characters such as "&", "+", "=" or "%" are
        # transmitted intact instead of corrupting the request body.
        data = {
            "client_id": CLIENT_ID,
            "client_secret": CLIENT_SECRET,
            "grant_type": "client_credentials",
        }
        token_res = requests.post(AUTH_URL, data=data, headers=self._headers)
        print("Auth response: ", token_res)
        if not token_res.ok:
            # Fail the same way the spreadsheet calls do rather than with a
            # KeyError on the missing "access_token" field.
            print(f"Error getting auth token: {token_res.text}")
            sys.exit(1)
        token_response = json.loads(token_res.text)
        return token_response["access_token"]


class SpreadsheetApi:
    """
    Client for the spreadsheet endpoints under SS_API_URL that hides zero
    rows and all-zero content sections in every sheet of a document.

    Every request is sent with the bearer token given to the constructor.
    A non-OK response prints the response body and exits the process with
    status 1.
    """

    def __init__(self, access_token):
        self._headers = {"Authorization": f"Bearer {access_token}"}
        # Running count of rows hidden through hide_table_rows.
        self._totalRowsHidden = 0

    @staticmethod
    def _exit_on_error(response, action):
        """
        Prints the response body and exits with status 1 if the request failed.
        """
        if not response.ok:
            print(f"Error {action}: {response.text}")
            sys.exit(1)

    @staticmethod
    def _displayed_value(cell):
        """
        Returns the cell's calculated value scaled and rounded as displayed.

        Values that cannot be read as a number are returned unchanged.
        """
        value = cell["calculatedValue"]
        if not isinstance(value, decimal.Decimal):
            try:
                # float() is the gate that rejects non-numeric text; the
                # value itself is kept as an exact Decimal.
                float(value)
                value = decimal.Decimal(value)
            except (TypeError, ValueError, decimal.InvalidOperation):
                # Not a number (text, empty string, None, ...).
                return value

        value_format = cell["formats"]["valueFormat"]

        shown_in = value_format["shownIn"]
        if shown_in:
            scale = NumberPrecision[shown_in.replace(" ", "_")].value
            # Decimal cannot be divided by a float (TypeError), and some
            # scales are floats, so go through the scale's decimal string.
            value /= decimal.Decimal(str(scale))

        precision = value_format["precision"]
        if precision and not precision["auto"]:
            # NOTE(review): quantize() only uses the exponent of its
            # argument - confirm the sign convention of precision["value"].
            value = value.quantize(
                decimal.Decimal(10) ** precision["value"],
                decimal.ROUND_HALF_UP,
            )

        return value

    def get_document_tables(self, doc_id):
        """
        Retrieves a list of identifiers for all the tables contained in a document.
        """
        url = f"{SS_API_URL}{doc_id}/sheets"
        response = requests.get(url, headers=self._headers)
        self._exit_on_error(response, "getting document tables")
        return [table["id"] for table in json.loads(response.text)["data"]]

    def get_table_data(self, doc_id, table_id):
        """
        Retrieves the data for a table.

        Returns the "cells" entry of the response, with JSON floats parsed
        as decimal.Decimal so that no precision is lost.
        """
        url = f"{SS_API_URL}{doc_id}/sheets/{table_id}/sheetdata"
        response = requests.get(url, headers=self._headers)
        self._exit_on_error(response, "getting table data")
        return json.loads(response.text, parse_float=decimal.Decimal)["data"]["cells"]

    def hide_table_rows(self, doc_id, table_id, row_indices):
        """
        Makes a request to hide the rows with indices specified in row_indices.

        Consecutive indices are merged into {"start", "end"} intervals.
        Does nothing when row_indices is empty. The argument is not modified.
        """
        if not row_indices:
            return

        # Work on a sorted, de-duplicated copy: the caller's sequence is left
        # alone and a repeated index is neither sent nor counted twice.
        ordered = sorted(set(row_indices))

        intervals = []
        start_index = end_index = ordered[0]
        for index in ordered[1:]:
            if index > end_index + 1:
                # Gap found: close the current run and start a new one.
                intervals.append({"start": start_index, "end": end_index})
                start_index = index
            end_index = index
        intervals.append({"start": start_index, "end": end_index})

        url = f"{SS_API_URL}{doc_id}/sheets/{table_id}/update"
        response = requests.post(
            url,
            headers=self._headers,
            json={"hideRows": {"intervals": intervals}},
        )
        self._exit_on_error(response, "hiding table rows")

        self._totalRowsHidden += len(ordered)

    def unhide_table_rows(self, doc_id, table_id):
        """
        Makes a request to unhide all the rows within a table.
        """
        # By not specifying start and end they become infinite
        infinite_interval = {}
        intervals = {"unhideRows": {"intervals": [infinite_interval]}}

        url = f"{SS_API_URL}{doc_id}/sheets/{table_id}/update"
        response = requests.post(url, headers=self._headers, json=intervals)
        self._exit_on_error(response, "unhiding table rows")

    def get_rows_as_displayed(self, doc_id, table_id):
        """
        Uses information from table data to create a list of rows with
        rounded and scaled display values.

        Numeric cells become decimal.Decimal values; every other cell keeps
        its calculated value as returned by get_table_data.
        """
        return [
            [self._displayed_value(cell) for cell in row]
            for row in self.get_table_data(doc_id, table_id)
        ]

    def section_rows_to_hide(
        self,
        start_row,
        stop_row,
        zero_rows,
        has_numeric_data,
        has_non_zero_numeric_data,
    ):
        """
        Creates a list of row indices to hide for a content section.

        A section with non-zero data hides only its zero rows; a section
        whose numeric rows are all zero is hidden entirely, from start_row
        through stop_row inclusive.
        """
        if has_non_zero_numeric_data:
            return zero_rows
        if has_numeric_data:
            return list(range(start_row, stop_row + 1))

        # Don't hide sections with no numeric rows.
        return []

    def find_rows_to_hide(self, rows):
        """
        Creates a list of row indices that should be hidden, determined by
        whether the row is a zero row or part of a content section that
        consists entirely of zero rows.

        A content section is a run of non-spacer rows. A spacer row is a row
        in which every cell is empty; a numeric zero counts as content.
        """
        rows_to_hide = []

        # Content section variables
        title_row = None
        zero_rows = []
        has_numeric_data = False
        has_non_zero_numeric_data = False

        for i, row in enumerate(rows):

            # Row variables
            is_spacer_row = True
            has_nums = False
            all_zeroes = True

            for cell in row:
                is_number = isinstance(cell, decimal.Decimal)
                # Decimal(0) is falsy, so numbers are checked explicitly:
                # a row holding only zeros is a zero row, not a spacer.
                if cell or is_number:
                    is_spacer_row = False
                if is_number:
                    has_nums = True
                    if cell != 0:  # This row has numbers, but is not a zero row
                        all_zeroes = False
                        break

            if is_spacer_row:
                # Compare against None: row index 0 is a valid title row.
                if title_row is not None:
                    rows_to_hide.extend(
                        self.section_rows_to_hide(
                            title_row,
                            i,
                            zero_rows,
                            has_numeric_data,
                            has_non_zero_numeric_data,
                        )
                    )
                    title_row = None
                    zero_rows = []
                    has_numeric_data = False
                    has_non_zero_numeric_data = False
            else:
                if title_row is None:
                    title_row = i
                if has_nums:
                    has_numeric_data = True
                    if all_zeroes:
                        zero_rows.append(i)
                    else:
                        has_non_zero_numeric_data = True

        # Close a section that runs to the last row. title_row can only be
        # set inside the loop, so i is bound here.
        if title_row is not None:
            rows_to_hide.extend(
                self.section_rows_to_hide(
                    title_row, i, zero_rows, has_numeric_data, has_non_zero_numeric_data
                )
            )

        return rows_to_hide

    def hide_rows(self, doc_id):
        """
        Hides all the zero rows and empty content sections in every table
        in the document specified by the document identifier.
        """
        table_ids = self.get_document_tables(doc_id)
        print(f"Hiding rows in {len(table_ids)} tables")
        for i, table_id in enumerate(table_ids):
            print(f"Hiding rows in table {i + 1}: {table_id}")
            rows_as_displayed = self.get_rows_as_displayed(doc_id, table_id)
            rows_to_hide = self.find_rows_to_hide(rows_as_displayed)
            self.hide_table_rows(doc_id, table_id, rows_to_hide)
        print(f"Total rows hidden: {self._totalRowsHidden}")

    def unhide_all_rows(self, doc_id):
        """
        Unhides every row in every table in the document specified by the
        document identifier.
        """
        table_ids = self.get_document_tables(doc_id)
        print(f"Unhiding {len(table_ids)} tables")
        for i, table_id in enumerate(table_ids):
            print(f"Unhiding table {i + 1}: {table_id}")
            self.unhide_table_rows(doc_id, table_id)


def main():
    """
    Authenticates, resets every sheet of DOCUMENT_ID to fully visible, and
    then hides its zero rows and all-zero content sections.
    """
    api = SpreadsheetApi(ApiAuth().get_auth_token())

    # Unhide first so that rows hidden earlier are shown again before the
    # hiding pass runs.
    api.unhide_all_rows(DOCUMENT_ID)
    api.hide_rows(DOCUMENT_ID)


# Script entry point. There is no `if __name__ == "__main__":` guard, so the
# script also runs when this module is imported.
main()
