import {
  TuningStep,
  TuningJobDefStruct,
  TrainingJobDefStruct,
  ParameterRangeStruct,
  TuningArguments,
} from "@/@types/project/mlPipeline/SageMaker/pipeline-tuning";
import { Hyperparameter } from "@/@types/project/mlPipeline/SageMaker/pipeline-training";
import {
  parseUri,
  parseS3Prefix,
  checkUseParam,
  parseFramework,
  parseFrameworkVersion,
} from "./helper-functions";
import { parseTrainingInputs } from "./training-step-parser";

// ----------------------------------------------------------------------

/**
 * Builds the editor's TuningStep node from one raw SageMaker pipeline step.
 *
 * @param index - Position of the step in the pipeline; used for the node id
 *   and for the horizontal node offset (220 px per position).
 * @param pipelineStep - Raw step object; its `Name` and `Arguments` (read as
 *   TuningArguments) are consumed.
 * @param stepPairs - Passed through unchanged to parseTrainingInputs.
 * @returns The assembled step. It is asserted to TuningStep, not structurally
 *   checked against it.
 */
export default function parseTuningStep(
  index: number,
  pipelineStep: any,
  stepPairs: [string, string, string][]
) {
  const {
    HyperParameterTuningJobConfig: tuningConfig,
    TrainingJobDefinition: trainingDefinition,
  } = pipelineStep.Arguments as TuningArguments;
  const stepName = pipelineStep.Name;
  const objective = tuningConfig.HyperParameterTuningJobObjective;

  return {
    id: `${index}`,
    type: "Tuning",
    name: stepName,
    metricName: parseUri(objective.MetricName),
    tuningType: parseUri(objective.Type),
    tuningTypeUseParam: checkUseParam(objective.Type),
    tuningJob: parseTuningJobProps(tuningConfig),
    trainingJob: parseTrainingJobProps(trainingDefinition),
    trainingInputs: parseTrainingInputs(
      trainingDefinition.InputDataConfig,
      stepName,
      stepPairs
    ),
    // Nodes are laid out left-to-right on a single row.
    nodeX: index * 220,
    nodeY: 0,
    tags: [],
    properties: [],
    stepType: "Tuning",
  } as TuningStep;
}

/**
 * Maps a HyperParameterTuningJobConfig onto the editor's TuningJobDefStruct.
 * Each scalar setting is stored twice: its parseUri value and a *UseParam
 * flag from checkUseParam.
 *
 * @param jobConfiguration - Raw tuning config; reads Strategy,
 *   ResourceLimits, TrainingJobEarlyStoppingType and ParameterRanges.
 */
const parseTuningJobProps = (jobConfiguration: any): TuningJobDefStruct => {
  const strategy = jobConfiguration.Strategy;
  const limits = jobConfiguration.ResourceLimits;
  const earlyStopping = jobConfiguration.TrainingJobEarlyStoppingType;
  const ranges = jobConfiguration.ParameterRanges;

  return {
    strategyType: parseUri(strategy),
    strategyTypeUseParam: checkUseParam(strategy),
    maxTrainingJob: parseUri(limits.MaxNumberOfTrainingJobs),
    maxTrainingJobUseParam: checkUseParam(limits.MaxNumberOfTrainingJobs),
    maxParallelTrainingJob: parseUri(limits.MaxParallelTrainingJobs),
    maxParallelTrainingJobUseParam: checkUseParam(
      limits.MaxParallelTrainingJobs
    ),
    trainingJobEarlyStoppingType: parseUri(earlyStopping),
    trainingJobEarlyStoppingTypeUseParam: checkUseParam(earlyStopping),
    parameterRanges: parseParameterRanges(
      ranges.ContinuousParameterRanges,
      ranges.CategoricalParameterRanges,
      ranges.IntegerParameterRanges
    ),
  };
};

/**
 * Maps a tuning job's TrainingJobDefinition onto the editor's
 * TrainingJobDefStruct. Scalar settings are stored as a parseUri value plus a
 * *UseParam flag from checkUseParam.
 *
 * @param jobConfiguration - Raw training job definition; reads
 *   AlgorithmSpecification.TrainingImage, the instance settings,
 *   OutputDataConfig.S3OutputPath and StaticHyperParameters.
 */
const parseTrainingJobProps = (jobConfiguration: any) => {
  const trainingImage = jobConfiguration.AlgorithmSpecification.TrainingImage;
  // A training job definition may carry its instance settings under either
  // HyperParameterTuningResourceConfig or the classic ResourceConfig. Prefer
  // the tuning-specific key (previous behaviour) and fall back to the other
  // one instead of crashing when it is absent.
  const resources =
    jobConfiguration.HyperParameterTuningResourceConfig ??
    jobConfiguration.ResourceConfig;
  const outputConfig = jobConfiguration.OutputDataConfig;

  const trainingJobDefProps: TrainingJobDefStruct = {
    estimatorType: parseFramework(trainingImage),
    estimatorTypeUseParam: checkUseParam(trainingImage),
    frameworkVersion: parseFrameworkVersion(trainingImage),
    instanceType: parseUri(resources.InstanceType),
    instanceTypeUseParam: checkUseParam(resources.InstanceType),
    instanceCount: parseUri(resources.InstanceCount),
    instanceCountUseParam: checkUseParam(resources.InstanceCount),
    volumeSizeInGb: parseUri(resources.VolumeSizeInGB),
    volumeSizeInGbUseParam: checkUseParam(resources.VolumeSizeInGB),
    outputPath: parseS3Prefix(outputConfig.S3OutputPath),
    outputPathUseParam: checkUseParam(outputConfig.S3OutputPath),
    hyperparameters: parseHyperParameter(
      jobConfiguration.StaticHyperParameters
    ),
  };
  return trainingJobDefProps;
};

/**
 * Flattens the three SageMaker hyperparameter range lists into one list of
 * ParameterRangeStruct entries. Output order is continuous, then categorical,
 * then integer ranges.
 *
 * Each entry is tagged with the kind of list it came from ("Continuous",
 * "Categorical" or "Integer"). A missing (undefined/null) list is treated as
 * empty.
 *
 * @param continuousPrs - ContinuousParameterRanges entries.
 * @param categoricalPrs - CategoricalParameterRanges entries.
 * @param integerPrs - IntegerParameterRanges entries.
 */
const parseParameterRanges = (
  continuousPrs: any,
  categoricalPrs: any,
  integerPrs: any
) => {
  const groups: Array<[any, ParameterRangeStruct["type"]]> = [
    [continuousPrs, "Continuous"],
    // NOTE(review): in the SageMaker API, categorical ranges carry `Values`
    // rather than Min/Max/ScalingType. The fields read below are then
    // undefined. Confirm whether ParameterRangeStruct should carry the values.
    [categoricalPrs, "Categorical"],
    [integerPrs, "Integer"],
  ];

  const parameterRanges: Array<ParameterRangeStruct> = [];
  for (const [ranges, kind] of groups) {
    for (const pr of ranges ?? []) {
      parameterRanges.push({
        name: pr.Name,
        type: kind,
        minValue: parseUri(pr.MinValue),
        minValueUseParam: checkUseParam(pr.MinValue),
        maxValue: parseUri(pr.MaxValue),
        maxValueUseParam: checkUseParam(pr.MaxValue),
        scalingType: parseUri(pr.ScalingType),
        scalingTypeUseParam: checkUseParam(pr.ScalingType),
      });
    }
  }
  return parameterRanges;
};

/**
 * Converts a static hyperparameter map into Hyperparameter entries. Keys
 * prefixed with "sagemaker_" are skipped.
 *
 * @param hyperParameters - Key/value map (e.g. StaticHyperParameters); each
 *   value goes through parseUri and checkUseParam. An undefined or null map
 *   yields an empty list, because StaticHyperParameters is optional.
 * @returns Entries in the map's key iteration order.
 */
const parseHyperParameter = (hyperParameters: any) => {
  const hps: Array<Hyperparameter> = [];
  for (const [name, value] of Object.entries(hyperParameters ?? {})) {
    // Reserved SageMaker keys (e.g. "sagemaker_program") use this prefix.
    // Match the prefix only, so user keys that merely contain the text are kept.
    if (name.startsWith("sagemaker_")) continue;
    hps.push({
      name,
      value: parseUri(value),
      useParam: checkUseParam(value),
    });
  }
  return hps;
};
