import { Schema, SchemaEdgeEntity } from "../../slices/slice";
import {
  Step,
  EnvVariable,
  PipelineManager,
  PipelineParameter,
} from "@/@types/project/mlPipeline/SageMaker/pipeline";
import {
  generateFailStep,
  generateModelStep,
  generateTuningStep,
  generateLambdaStep,
  generateTrainingStep,
  generateConditionStep,
  generateTransformStep,
  generateProcessingStep,
  generateClarifyCheckStep,
  generateRegisterModelStep,
} from "../step-generators";
import { FailStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-fail";
import { ModelStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-model";
import { LambdaStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-lambda";
import { TuningStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-tuning";
import { TransformStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-transform";
import { ConditionStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-condition";
import { ProcessingStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-processing";
import { ClarifyCheckStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-clarify-check";
import { RegisterModelStep } from "@/@types/project/mlPipeline/SageMaker/pipeline-register-model";
import { TrainingStep, Hyperparameter } from "@/@types/pipeline-training";

/**
 * Build the pipeline-definition entry for a single step.
 *
 * Dispatches on `step.type` to the matching generator from `../step-generators`.
 * The Condition and ClarifyCheck generators are awaited. Afterwards a `DependsOn`
 * array is attached when the step has direct parents in the schema's edge list
 * (parents of type "Condition" are excluded by `findParentSteps`).
 * An unrecognised step type yields the placeholder `{ Name: "", Type: "", Arguments: {} }`
 * (plus `DependsOn` when applicable).
 *
 * @param step - the step to generate.
 * @param schema - the full pipeline schema (steps and edges), used for dependencies
 *   and by the Condition generator.
 * @param pipeline - pipeline-level settings (role ARN, pipeline parameters).
 * @returns the generated step object.
 */
export default async function generateStep(
  step: Step,
  schema: Schema,
  pipeline: PipelineManager
) {
  // Placeholder kept as-is for step types none of the generators handle.
  let stepDefinition: any = { Name: "", Type: "", Arguments: {} };

  switch (step.type) {
    case "Processing":
      stepDefinition = generateProcessingStep(
        step as ProcessingStep,
        pipeline.RoleArn
      );
      break;
    case "Training":
      stepDefinition = generateTrainingStep(
        step as TrainingStep,
        pipeline.RoleArn,
        pipeline.pipelineParamList
      );
      break;
    case "Tuning":
      stepDefinition = generateTuningStep(
        step as TuningStep,
        pipeline.RoleArn,
        pipeline.pipelineParamList
      );
      break;
    case "Model":
      stepDefinition = generateModelStep(step as ModelStep, pipeline.RoleArn);
      break;
    case "Condition":
      // The condition generator needs the whole schema and pipeline and is awaited.
      stepDefinition = await generateConditionStep(
        step as ConditionStep,
        schema,
        pipeline
      );
      break;
    case "RegisterModel":
      stepDefinition = generateRegisterModelStep(step as RegisterModelStep);
      break;
    case "Transform":
      stepDefinition = generateTransformStep(step as TransformStep);
      break;
    case "Fail":
      stepDefinition = generateFailStep(step as FailStep);
      break;
    case "Lambda":
      stepDefinition = generateLambdaStep(step as LambdaStep);
      break;
    case "ClarifyCheck":
      stepDefinition = await generateClarifyCheckStep(
        step as ClarifyCheckStep,
        pipeline.RoleArn
      );
      break;
  }

  const parentNames = findParentSteps(
    step,
    schema.pipelineSteps,
    schema.edgeEntities
  );
  // Only emit DependsOn when there is at least one parent.
  if (parentNames.length) {
    stepDefinition["DependsOn"] = parentNames;
  }
  return stepDefinition;
}

/**
 * Collect the names of the steps that have an edge pointing into `step`.
 *
 * Edges whose source id matches no step are ignored, and so are parents whose
 * type is "Condition". When several steps share an id, the first one in `steps`
 * wins, matching `Array.prototype.find`. Names are returned in edge order.
 *
 * @param step - the step whose parents are wanted.
 * @param steps - all steps of the pipeline schema.
 * @param edgeEntities - all edges of the pipeline schema.
 * @returns the parent step names (possibly empty).
 */
const findParentSteps = (
  step: Step,
  steps: Step[],
  edgeEntities: SchemaEdgeEntity[]
): Step["name"][] => {
  // Index steps by id once instead of scanning `steps` for every edge.
  const stepsById = new Map<Step["id"], Step>();
  for (const candidate of steps) {
    // Keep the first occurrence to mirror `find` semantics on duplicate ids.
    if (!stepsById.has(candidate.id)) stepsById.set(candidate.id, candidate);
  }

  const parentNames: Step["name"][] = [];
  for (const edge of edgeEntities) {
    if (edge.targetEntityId !== step.id) continue;
    const parent = stepsById.get(edge.sourceEntityId);
    if (parent && parent.type !== "Condition") parentNames.push(parent.name);
  }
  return parentNames;
};

/**
 * Turn a user-entered string into an array of values.
 *
 * If the string contains a comma, it is split on commas and the pieces are
 * returned verbatim (not trimmed and not passed through `generateUri`).
 * Otherwise the single value is resolved with `generateUri(input, useParam)`
 * and wrapped in a one-element array.
 * NOTE(review): confirm that comma-separated entries are meant to skip trimming/URI resolution.
 */
export const genArraysFromString = (input: string, useParam?: boolean) =>
  input.includes(",") ? input.split(",") : [generateUri(input, useParam)];

// export const generateImageUri = (framework: string) => {
//   if (framework.toLowerCase().includes("pytorch"))
//     return "763104351884.dkr.ecr.us-east-2.amazonaws.com/pytorch-training:1.8.1-gpu-py3";
//   else if (framework.toLowerCase().includes("xgboost"))
//     return "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-xgboost:1.0-1-cpu-py3";
//   else if (framework.toLowerCase().includes("sklearn"))
//     return "257758044811.dkr.ecr.us-east-2.amazonaws.com/sagemaker-scikit-learn:0.23-1-cpu-py3";
//   else if (framework.toLowerCase().includes("tensorflow")) return "TensorFlow";
//   return "";
// };

/**
 * Resolve a container image reference.
 *
 * - `useParam` truthy: the value is treated as a parameter/reference via `generateUri`.
 * - A value containing ".com" is assumed to already be a full registry URI and is returned as-is.
 * - Otherwise the lower-cased framework string is appended to a fixed us-east-2 ECR
 *   registry chosen by the first keyword it contains (xgboost, sklearn, tensorflow).
 * - No match yields "".
 */
export const generateImageUri = (framework: string, useParam?: boolean) => {
  if (useParam) {
    return generateUri(framework, useParam);
  }

  if (framework.includes(".com")) {
    return framework;
  }

  // Order matters: the first keyword found in the framework name wins.
  const registries: Array<[string, string]> = [
    ["xgboost", "257758044811.dkr.ecr.us-east-2.amazonaws.com"],
    ["sklearn", "257758044811.dkr.ecr.us-east-2.amazonaws.com"],
    ["tensorflow", "520713654638.dkr.ecr.us-east-2.amazonaws.com"],
  ];
  const lowered = framework.toLowerCase();
  const hit = registries.find(([keyword]) => lowered.includes(keyword));
  return hit ? `${hit[1]}/${lowered}` : "";
};

/**
 * Convert a list of environment-variable entries into a plain `{ name: value }` object.
 *
 * When a name appears more than once, the last value wins (standard
 * `Object.fromEntries` behavior).
 *
 * @param envVars - the entries to convert; `undefined` yields `{}`.
 * @returns an object mapping each entry's `name` to its `value`.
 */
export const formEnvVars = (envVars: EnvVariable[] | undefined) => {
  if (envVars === undefined) return {};

  // Build entries directly; the explicit tuple type keeps fromEntries strongly typed.
  return Object.fromEntries(
    envVars.map((envVar): [string, string] => [envVar.name, envVar.value])
  );
};

/**
 * Convert hyperparameter entries into a `{ name: value }` object.
 *
 * Entries flagged `useParam` have their value resolved against the pipeline
 * parameters by `formHyperparameterValue`, which may produce a reference object.
 * Other entries keep their raw value. Duplicate names: the last entry wins.
 *
 * @param hps - the hyperparameters; `undefined` yields `{}`.
 * @param pipelineParams - the pipeline parameters available for references.
 * @returns the hyperparameter map.
 */
export const formHyperparameters = (
  hps: Hyperparameter[] | undefined,
  pipelineParams: PipelineParameter[]
) => {
  if (hps === undefined) return {};

  // `object` (not the boxed `Object` wrapper type) describes reference values.
  return Object.fromEntries(
    hps.map((hp): [string, string | object] => [
      hp.name,
      hp.useParam ? formHyperparameterValue(hp.value, pipelineParams) : hp.value,
    ])
  );
};

/**
 * Resolve a hyperparameter value that may name a pipeline parameter.
 *
 * - No parameter with that `Name`: the value is returned unchanged.
 * - "Integer" or "Float" parameter: the `Get` reference is wrapped in a `Std:Join`
 *   with an empty separator.
 * - Any other type: a bare `{ Get: "Parameters.<name>" }` reference is returned.
 */
const formHyperparameterValue = (
  hpVal: string,
  pipelineParams: PipelineParameter[]
) => {
  const matched = pipelineParams.find(({ Name }) => Name === hpVal);
  if (!matched) {
    return hpVal;
  }

  const reference = { Get: `Parameters.${hpVal}` };
  const isNumeric = matched.Type === "Integer" || matched.Type === "Float";
  return isNumeric
    ? { "Std:Join": { On: "", Values: [reference] } }
    : reference;
};

/**
 * Produce the pipeline value for a URI-like field.
 *
 * - `s3Uri` undefined: returns "".
 * - `s3UriUseParam` truthy: resolved as a parameter/step reference by `formInputData`.
 * - Otherwise: passed through `reformUri`, which expands `${...}` placeholders and
 *   optionally coerces to a number when `dataType` is "number".
 */
export const generateUri = (
  s3Uri?: string,
  s3UriUseParam?: boolean,
  dataType?: string
) => {
  if (s3Uri === undefined) {
    return "";
  }
  if (s3UriUseParam) {
    return formInputData(s3Uri, s3UriUseParam);
  }
  return reformUri(s3Uri.toString(), dataType);
};

/**
 * Return either a pipeline-parameter `Get` reference (when `s3UriUseParam` is truthy)
 * or the `reformUri` expansion of `s3Uri`.
 */
export const generateJsonGet = (s3Uri: string, s3UriUseParam?: boolean) =>
  s3UriUseParam
    ? { Get: `Parameters.${s3Uri}` }
    : reformUri(s3Uri.toString());

/**
 * Return either a pipeline-parameter `Get` reference (when `pathUseParam` is truthy)
 * or `path` qualified with `s3://<bucket>/` via `reformPathWRTBucket`.
 */
export const formPath = (
  path: string,
  bucket: string,
  pathUseParam?: boolean
) =>
  pathUseParam
    ? { Get: `Parameters.${path}` }
    : reformPathWRTBucket(path, bucket);

/**
 * Return a pipeline-parameter `Get` reference when `useParam` is truthy,
 * otherwise `Number(input)` (which may be NaN for non-numeric text).
 */
export const formNumber = (input: string, useParam?: boolean) =>
  useParam ? { Get: `Parameters.${input}` } : Number(input);

/**
 * Resolve an input value into either a literal or a pipeline `Get` reference.
 *
 * When `dataUseParam` is falsy, `data` is returned unchanged. Otherwise `data`
 * is split on "." and interpreted as:
 *  - `Param` → `{ Get: "Parameters.Param" }`
 *  - `Step.ProcessingOutputs.Output` → the processing output's S3Uri
 *  - `Step.ModelArtifacts` → `Steps.Step.ModelArtifacts.S3ModelArtifacts`
 *  - `Step.PropertyFiles.File` → `Steps.Step.PropertyFiles.File`
 *  - `Step.ModelName` / `Step.ModelPackageArn` /
 *    `Step.CalculatedBaselineConstraints` / `Step.BaselineUsedForDriftCheckConstraints`
 *    → `Steps.Step.<Property>`
 *  - any other dotted value → the generic `{ Get: "Steps.<data>" }` property path.
 */
export const formInputData = (data: string, dataUseParam?: boolean) => {
  if (!dataUseParam) return data;

  const [stepName, property, key] = data.split(".");

  // A single segment names a pipeline parameter.
  if (property === undefined) return { Get: `Parameters.${data}` };

  switch (property) {
    case "ProcessingOutputs":
      // Build the path directly; a string replace would also rewrite step names
      // that happen to contain "ProcessingOutputsConfig".
      return {
        Get: `Steps.${stepName}.ProcessingOutputConfig.Outputs['${key}'].S3Output.S3Uri`,
      };
    case "ModelArtifacts":
      return { Get: `Steps.${stepName}.ModelArtifacts.S3ModelArtifacts` };
    case "PropertyFiles":
      return { Get: `Steps.${stepName}.PropertyFiles.${key}` };
    case "ModelName":
    case "ModelPackageArn":
    case "CalculatedBaselineConstraints": // Clarify Check
    case "BaselineUsedForDriftCheckConstraints": // Clarify Check
      return { Get: `Steps.${stepName}.${property}` };
    default:
      // Unrecognised property: fall back to a generic step property path
      // instead of silently yielding `undefined`.
      return { Get: `Steps.${data}` };
  }
};

/**
 * Qualify a bucket-relative path as `s3://<bucket>/<path>`; a path that already
 * contains "s3:/" anywhere is returned unchanged.
 */
const reformPathWRTBucket = (path: string, bucket: string) => {
  if (path.includes("s3:/")) {
    return path;
  }
  return `s3://${bucket}/${path}`;
};

/**
 * Split a template containing `${...}` placeholders into literal segments and
 * `{ Get: ... }` references, in order, dropping empty literal segments.
 */
const splitPlaceholders = (template: string) =>
  template
    .split(/\$\{([^}]*)\}/)
    // With a capturing group in the separator, odd indices hold placeholder contents.
    .map((part, index) => (index % 2 === 1 ? { Get: part } : part))
    .filter((part) => part !== "");

/**
 * Expand a value that may contain `${...}` placeholders into pipeline JSON.
 *
 * - No placeholder: the string is returned, coerced with `Number` when `dataType` is "number".
 * - Contains ".PropertyFiles.": becomes a `Std:JsonGet` whose PropertyFile is
 *   `Steps.<text before the first "/">` and whose Path is the text after it.
 * - Placeholder at the very start: there is no connector character, so the string
 *   is tokenized into literals and `Get` references joined with `On: ""`.
 * - Otherwise: the character immediately before the first "${" is used as the
 *   connector; the string is split on it (empty pieces dropped, "s3:" restored to
 *   "s3:/"), placeholders become `Get` references, and the result is a `Std:Join`
 *   on that connector.
 */
const reformUri = (s3Uri: string, dataType?: string) => {
  const pos = s3Uri.indexOf("${");
  if (pos >= 0) {
    const pos2 = s3Uri.indexOf(".PropertyFiles.");
    if (pos2 >= 0) {
      const str = s3Uri.replace("${", "").replace("}", "");
      const items = str.split("/");
      return {
        "Std:JsonGet": {
          PropertyFile: {
            Get: `Steps.${items[0]}`,
          },
          Path: items.length > 1 ? items[1] : "",
        },
      };
    } else if (pos === 0) {
      // `s3Uri[pos - 1]` would be undefined here: splitting on it keeps the whole
      // string (mangling the Get path) and would emit `On: undefined`.
      return { "Std:Join": { On: "", Values: splitPlaceholders(s3Uri) } };
    } else {
      const connector = s3Uri[pos - 1];
      const items = s3Uri.split(connector).filter((s) => s);
      const reformedItems = items.map((item) => {
        if (item === "s3:") {
          // Splitting "s3://" on "/" leaves "s3:" and an empty piece; joining
          // "s3:/" back with "/" restores the double slash.
          return "s3:/";
        } else if (item.indexOf("${") >= 0) {
          return { Get: item.split("{")[1].replace("}", "") };
        } else {
          return item;
        }
      });
      return { "Std:Join": { On: connector, Values: reformedItems } };
    }
  } else {
    return dataType && dataType === "number" ? Number(s3Uri) : s3Uri;
  }
};
