import {
  RegisterModelStep,
  ModelConfigStruct,
  RegisterModelArguments,
} from "@/@types/project/mlPipeline/SageMaker/pipeline-register-model";
import { parseUri, checkUseParam, parseFramework } from "./helper-functions";

// ----------------------------------------------------------------------

/**
 * Converts one `RegisterModel` entry of a SageMaker pipeline definition into
 * the editor's `RegisterModelStep` node.
 *
 * The node is placed at canvas position (0, 0) with empty `tags` and
 * `properties`. `metrics` and `driftCheckBaselines` appear on the result only
 * when the step's arguments contain `ModelMetrics` / `DriftCheckBaselines`.
 *
 * @param index - position of the step; stringified and used as the node `id`
 * @param pipelineStep - raw step object from the pipeline definition
 * @param stepPairs - forwarded to the nested parsers (presumably used to
 *   resolve references to other steps' outputs — confirm against callers)
 * @returns the parsed `RegisterModelStep`
 */
export default function parseRegisterModelStep(
  index: number,
  pipelineStep: any,
  stepPairs: [string, string, string][]
) {
  const { Name: stepName } = pipelineStep;
  const args: RegisterModelArguments = pipelineStep.Arguments;

  // Evaluated in the same order as the fields appear on the result.
  const approvalStatus = parseUri(args.ModelApprovalStatus);
  const approvalStatusUseParam = checkUseParam(args.ModelApprovalStatus);
  const model = parseModelPropsForRegister(
    args.ModelPackageGroupName,
    args.InferenceSpecification,
    stepName,
    stepPairs
  );

  // Optional sections: spread an empty object when the argument is absent so
  // the key is left off the result entirely.
  const optionalMetrics = args.ModelMetrics
    ? { metrics: parseMetrics(args.ModelMetrics, stepName, stepPairs) }
    : {};
  const optionalDriftCheckBaselines = args.DriftCheckBaselines
    ? {
        driftCheckBaselines: parseDriftCheckBaselines(
          args.DriftCheckBaselines,
          stepName,
          stepPairs
        ),
      }
    : {};

  return {
    id: `${index}`,
    type: "RegisterModel",
    name: stepName,
    approvalStatus,
    approvalStatusUseParam,
    model,
    nodeX: 0,
    nodeY: 0,
    tags: [],
    properties: [],
    stepType: "RegisterModel",
    ...optionalMetrics,
    ...optionalDriftCheckBaselines,
  } as RegisterModelStep;
}

/**
 * Flattens the `ModelMetrics` section of a RegisterModel step into the
 * editor's metric fields.
 *
 * Only one report per metric family is read:
 * - `ModelQuality.Statistics`      → `modelStatistics*`
 * - `Bias.PostTrainingReport`      → `bias*`
 * - `Explainability.Report`        → `explainability*`
 *
 * Each report that is present contributes three fields: the value returned by
 * `parseUri` for its `S3Uri`, a `…UseParam` flag from `checkUseParam`, and
 * its `…ContentType`. A family whose report is absent (SageMaker also allows
 * e.g. `ModelQuality` with only `Constraints`, or `Bias` with only `Report`)
 * is skipped instead of throwing. Other families are ignored.
 *
 * @param modelMetrics - the step's `Arguments.ModelMetrics` object (must be non-null)
 * @param stepName - name of the step being parsed, forwarded to `parseUri`
 * @param stepPairs - forwarded to `parseUri`
 * @returns an object holding only the fields of the reports that were found
 */
const parseMetrics = (
  modelMetrics: any,
  stepName: string,
  stepPairs: [string, string, string][]
) => {
  const res = {};

  // Iterate the input's own keys so the result's key order follows the input,
  // as before.
  for (const metric of Object.keys(modelMetrics)) {
    switch (metric) {
      case "ModelQuality": {
        const statistics = modelMetrics.ModelQuality?.Statistics;
        if (statistics) {
          Object.assign(res, {
            modelStatistics: parseUri(
              statistics.S3Uri,
              stepName,
              stepPairs,
              "ModelQuality.Statistics"
            ),
            modelStatisticsUseParam: checkUseParam(statistics.S3Uri),
            modelStatisticsContentType: statistics.ContentType,
          });
        }
        break;
      }
      case "Bias": {
        const report = modelMetrics.Bias?.PostTrainingReport;
        if (report) {
          Object.assign(res, {
            bias: parseUri(
              report.S3Uri,
              stepName,
              stepPairs,
              "Bias.PostTrainingReport"
            ),
            biasUseParam: checkUseParam(report.S3Uri),
            biasContentType: report.ContentType,
          });
        }
        break;
      }
      case "Explainability": {
        const report = modelMetrics.Explainability?.Report;
        if (report) {
          Object.assign(res, {
            explainability: parseUri(
              report.S3Uri,
              stepName,
              stepPairs,
              "Explainability.Report"
            ),
            explainabilityUseParam: checkUseParam(report.S3Uri),
            explainabilityContentType: report.ContentType,
          });
        }
        break;
      }
      default:
        // Metric families without an editor field (e.g. ModelDataQuality).
        break;
    }
  }

  return res;
};

/**
 * Flattens the `DriftCheckBaselines` section of a RegisterModel step into the
 * editor's drift-baseline fields.
 *
 * Sources read:
 * - `Bias.PostTrainingConstraints`  → `biasConfigFile*`
 * - `Explainability.ConfigFile`     → `explainabilityConfigFile*`
 * - `Explainability.Constraints`    → `explainabilityConstraints*`
 *
 * Each source that is present contributes three fields: the value returned by
 * `parseUri` for its `S3Uri`, a `…UseParam` flag from `checkUseParam`, and
 * its `…ContentType`. Missing sources (or a null family object) are skipped
 * rather than throwing. Other baseline families are ignored.
 *
 * @param driftCheckBaselines - the step's `Arguments.DriftCheckBaselines` object (must be non-null)
 * @param stepName - name of the step being parsed, forwarded to `parseUri`
 * @param stepPairs - forwarded to `parseUri`
 * @returns an object holding only the fields of the sources that were found
 */
const parseDriftCheckBaselines = (
  driftCheckBaselines: any,
  stepName: string,
  stepPairs: [string, string, string][]
) => {
  const res = {};

  // Iterate the input's own keys so the result's key order follows the input.
  for (const baseline of Object.keys(driftCheckBaselines)) {
    switch (baseline) {
      case "Bias": {
        // NOTE(review): the editor's `biasConfigFile` field is filled from
        // `Bias.PostTrainingConstraints`, not `Bias.ConfigFile`. Kept as-is;
        // confirm this matches the serializer that writes the step back.
        const constraints = driftCheckBaselines.Bias?.PostTrainingConstraints;
        if (constraints) {
          Object.assign(res, {
            biasConfigFile: parseUri(
              constraints.S3Uri,
              stepName,
              stepPairs,
              "Bias.PostTrainingConstraints"
            ),
            biasConfigFileUseParam: checkUseParam(constraints.S3Uri),
            biasConfigFileContentType: constraints.ContentType,
          });
        }
        break;
      }
      case "Explainability": {
        const configFile = driftCheckBaselines.Explainability?.ConfigFile;
        if (configFile) {
          Object.assign(res, {
            explainabilityConfigFile: parseUri(
              configFile.S3Uri,
              stepName,
              stepPairs,
              "Explainability.ConfigFile"
            ),
            explainabilityConfigFileUseParam: checkUseParam(configFile.S3Uri),
            explainabilityConfigFileContentType: configFile.ContentType,
          });
        }
        const constraints = driftCheckBaselines.Explainability?.Constraints;
        if (constraints) {
          Object.assign(res, {
            explainabilityConstraints: parseUri(
              constraints.S3Uri,
              stepName,
              stepPairs,
              "Explainability.Constraints"
            ),
            explainabilityConstraintsUseParam: checkUseParam(constraints.S3Uri),
            explainabilityConstraintsContentType: constraints.ContentType,
          });
        }
        break;
      }
      default:
        // Baseline families without an editor field (e.g. ModelQuality).
        break;
    }
  }

  return res;
};

/**
 * Builds the editor's `ModelConfigStruct` from a RegisterModel step's model
 * package group name and `InferenceSpecification`.
 *
 * Only the first entry of `Containers` is represented. The supported
 * content/response MIME types are joined into comma-separated strings. The
 * realtime/transform instance-type lists go through `parseInstanceTypes` and
 * `checkInstanceTypeUseParam`. `SupportedContentTypes`,
 * `SupportedResponseMIMETypes` and both instance-type lists are optional in
 * the SageMaker API, so a missing list is treated as empty rather than
 * throwing.
 *
 * @param modelPackageGroup - `Arguments.ModelPackageGroupName` (a literal or a parameter reference)
 * @param inferenceSpecification - `Arguments.InferenceSpecification`; must have at least one container
 * @param stepName - name of the step being parsed, forwarded to `parseUri`
 * @param stepPairs - forwarded to `parseUri`
 * @returns the model configuration shown in the editor
 */
const parseModelPropsForRegister = (
  modelPackageGroup: any,
  inferenceSpecification: any,
  stepName: string,
  stepPairs: [string, string, string][]
) => {
  const container = inferenceSpecification.Containers[0];
  const contentTypes = inferenceSpecification.SupportedContentTypes ?? [];
  const responseTypes = inferenceSpecification.SupportedResponseMIMETypes ?? [];
  const inferenceInstanceTypes =
    inferenceSpecification.SupportedRealtimeInferenceInstanceTypes ?? [];
  const transformInstanceTypes =
    inferenceSpecification.SupportedTransformInstanceTypes ?? [];

  const modelProperties: ModelConfigStruct = {
    modelType: "Model",
    modelData: parseUri(
      container.ModelDataUrl,
      stepName,
      stepPairs,
      "ModelDataUrl"
    ),
    modelDataUseParam: checkUseParam(container.ModelDataUrl),
    contentType: contentTypes.join(","),
    responseType: responseTypes.join(","),
    container: parseFramework(container.Image),
    containerUseParam: checkUseParam(container.Image),
    modelPackageGroupName: parseUri(modelPackageGroup),
    modelPackageGroupNameUseParam: checkUseParam(modelPackageGroup),
    inferenceInstanceType: parseInstanceTypes(inferenceInstanceTypes),
    inferenceInstanceTypeUseParam: checkInstanceTypeUseParam(
      inferenceInstanceTypes
    ),
    transformInstanceType: parseInstanceTypes(transformInstanceTypes),
    transformInstanceTypeUseParam: checkInstanceTypeUseParam(
      transformInstanceTypes
    ),
  };

  return modelProperties;
};

/**
 * Converts a list of supported instance types into the single string held by
 * the editor's instance-type field.
 *
 * - missing (`null`/`undefined`) or empty list → `""`
 * - exactly one entry → the result of `parseUri` on that entry (the
 *   single-entry case is also the only one `checkInstanceTypeUseParam` can
 *   flag as a parameter)
 * - several entries → the entries joined with `","`
 *
 * @param input - e.g. `SupportedRealtimeInferenceInstanceTypes`, which may be
 *   omitted from a pipeline definition
 */
const parseInstanceTypes = (input: any[] | null | undefined) => {
  if (!input || input.length === 0) {
    return "";
  }
  if (input.length === 1) {
    return parseUri(input[0]);
  }
  return input.join(",");
};

/**
 * Reports whether an instance-type list should be shown as a pipeline
 * parameter.
 *
 * Only a list with exactly one entry is checked, via `checkUseParam`. A
 * missing, empty or multi-entry list always yields `false`.
 *
 * @param input - e.g. `SupportedTransformInstanceTypes`, which may be omitted
 *   from a pipeline definition
 */
const checkInstanceTypeUseParam = (input: any[] | null | undefined) =>
  input?.length === 1 ? checkUseParam(input[0]) : false;
