import {
  TuningStep,
  ParameterRangeStruct,
  IntegerParameterRange,
  ContinuousParameterRange,
  CategoricalParameterRange,
} from "@/@types/project/mlPipeline/SageMaker/pipeline-tuning";
import {
  StepInfo,
  PipelineParameter,
} from "@/@types/project/mlPipeline/SageMaker/pipeline";
import {
  formPath,
  formNumber,
  generateUri,
  formInputData,
  generateImageUri,
  formHyperparameters,
} from "./helper-functions";
import { TrainingInputStruct } from "@/@types/project/mlPipeline/SageMaker/pipeline-training";

// ----------------------------------------------------------------------

/**
 * Builds the pipeline definition entry for a hyperparameter tuning step.
 *
 * The result carries the step's name and type, an `Arguments` object made of
 * a `HyperParameterTuningJobConfig` and a `TrainingJobDefinition`, and a
 * `CacheConfig` that is always enabled with `ExpireAfter` set to "30d".
 *
 * Each configurable value is handed to a helper from `./helper-functions`
 * together with its `...UseParam` flag.
 * NOTE(review): presumably the flag makes the helper emit a pipeline-parameter
 * reference instead of a literal value — confirm in helper-functions.
 *
 * @param step - Tuning step description (tuning job, training job, inputs).
 * @param roleArn - Written as `RoleArn` of the training job definition.
 * @param pipelineParams - Forwarded to `formHyperparameters` when building
 *   the static hyperparameters.
 * @returns The assembled step object, asserted to be a `StepInfo`.
 */
export default function generateTuningStep(
  step: TuningStep,
  roleArn: string,
  pipelineParams: PipelineParameter[]
) {
  const { tuningJob, trainingJob } = step;

  // Search-level settings: strategy, job limits, early stopping, objective
  // metric and the per-kind hyperparameter ranges.
  const tuningJobConfig = {
    Strategy: generateUri(
      tuningJob.strategyType,
      tuningJob.strategyTypeUseParam
    ),
    ResourceLimits: {
      MaxNumberOfTrainingJobs: formNumber(
        tuningJob.maxTrainingJob,
        tuningJob.maxTrainingJobUseParam
      ),
      MaxParallelTrainingJobs: formNumber(
        tuningJob.maxParallelTrainingJob,
        tuningJob.maxParallelTrainingJobUseParam
      ),
    },
    TrainingJobEarlyStoppingType: generateUri(
      tuningJob.trainingJobEarlyStoppingType,
      tuningJob.trainingJobEarlyStoppingTypeUseParam
    ),
    HyperParameterTuningJobObjective: {
      Type: generateUri(step.tuningType, step.tuningTypeUseParam),
      MetricName: step.metricName,
    },
    // The same list is passed to all three generators; each one keeps only
    // the entries of its own kind.
    ParameterRanges: {
      ContinuousParameterRanges: genContinuousParamRanges(
        tuningJob.parameterRanges
      ),
      CategoricalParameterRanges: genCategoricalParamRanges(
        tuningJob.parameterRanges
      ),
      IntegerParameterRanges: genIntegerParamRanges(tuningJob.parameterRanges),
    },
  };

  // Template shared by every training job the tuner launches.
  const trainingJobDefinition = {
    StaticHyperParameters: formHyperparameters(
      trainingJob.hyperparameters,
      pipelineParams
    ),
    OutputDataConfig: {
      S3OutputPath: formPath(
        trainingJob.outputPath,
        "",
        trainingJob.outputPathUseParam
      ),
    },
    HyperParameterTuningResourceConfig: {
      VolumeSizeInGB: formNumber(
        trainingJob.volumeSizeInGb,
        trainingJob.volumeSizeInGbUseParam
      ),
      InstanceCount: formNumber(
        trainingJob.instanceCount,
        trainingJob.instanceCountUseParam
      ),
      InstanceType: generateUri(
        trainingJob.instanceType,
        trainingJob.instanceTypeUseParam
      ),
    },
    AlgorithmSpecification: {
      TrainingInputMode: "File",
      TrainingImage: generateImageUri(
        trainingJob.estimatorType,
        trainingJob.estimatorTypeUseParam
      ),
    },
    InputDataConfig: step.trainingInputs.map((trainingInput) =>
      genTrainingInput(trainingInput)
    ),
    // Fixed runtime cap of 86400 seconds (24 hours) per training job.
    StoppingCondition: { MaxRuntimeInSeconds: 86400 },
    RoleArn: roleArn,
  };

  return {
    Name: step.name,
    Type: step.type,
    Arguments: {
      HyperParameterTuningJobConfig: tuningJobConfig,
      TrainingJobDefinition: trainingJobDefinition,
    },
    CacheConfig: { Enabled: true, ExpireAfter: "30d" },
  } as StepInfo;
}

/**
 * Collects the entries whose `type` is "Continuous" and converts each one
 * into a `ContinuousParameterRange`, preserving input order.
 *
 * Min, max and scaling type are each passed through `generateUri` together
 * with their `...UseParam` flag.
 *
 * @param paramRanges - Parameter ranges of every kind; other kinds are skipped.
 * @returns The converted continuous ranges (empty when none match).
 */
const genContinuousParamRanges = (paramRanges: ParameterRangeStruct[]) => {
  const continuousRanges: Array<ContinuousParameterRange> = [];

  for (const candidate of paramRanges) {
    if (candidate.type !== "Continuous") {
      continue;
    }

    continuousRanges.push({
      Name: candidate.name,
      MinValue: generateUri(candidate.minValue, candidate.minValueUseParam),
      MaxValue: generateUri(candidate.maxValue, candidate.maxValueUseParam),
      ScalingType: generateUri(
        candidate.scalingType,
        candidate.scalingTypeUseParam
      ),
    });
  }

  return continuousRanges;
};

/**
 * Collects the entries whose `type` is "Integer" and converts each one into
 * an `IntegerParameterRange`, preserving input order.
 *
 * Min, max and scaling type are each passed through `generateUri` together
 * with their `...UseParam` flag.
 *
 * @param paramRanges - Parameter ranges of every kind; other kinds are skipped.
 * @returns The converted integer ranges (empty when none match).
 */
const genIntegerParamRanges = (paramRanges: ParameterRangeStruct[]) => {
  const integerRanges: Array<IntegerParameterRange> = [];

  for (const candidate of paramRanges) {
    if (candidate.type !== "Integer") {
      continue;
    }

    integerRanges.push({
      Name: candidate.name,
      MinValue: generateUri(candidate.minValue, candidate.minValueUseParam),
      MaxValue: generateUri(candidate.maxValue, candidate.maxValueUseParam),
      ScalingType: generateUri(
        candidate.scalingType,
        candidate.scalingTypeUseParam
      ),
    });
  }

  return integerRanges;
};

/**
 * Collects the entries whose `type` is "Categorical" and converts each one
 * into a `CategoricalParameterRange`, preserving input order.
 *
 * The entry's `value` is a comma-separated list. Every item is trimmed and
 * empty items are dropped, so input such as "a, b," yields ["a", "b"] rather
 * than ["a", " b", ""]. A missing or empty `value` yields an empty list.
 *
 * @param paramRanges - Parameter ranges of every kind; other kinds are skipped.
 * @returns The converted categorical ranges (empty when none match).
 */
const genCategoricalParamRanges = (paramRanges: ParameterRangeStruct[]) => {
  const ranges: Array<CategoricalParameterRange> = [];

  for (const paramRange of paramRanges) {
    if (paramRange.type !== "Categorical") {
      continue;
    }

    // Whitespace around commas and stray separators are input noise, not
    // part of a category value; blank values are never meaningful choices.
    const values: string[] = paramRange.value
      ? paramRange.value
          .split(",")
          .map((item: string) => item.trim())
          .filter((item: string) => item.length > 0)
      : [];

    ranges.push({
      Name: paramRange.name,
      Values: values,
    });
  }

  return ranges;
};

/**
 * Builds one input-channel entry for the training job definition.
 *
 * The entry always contains an S3 data source (`S3DataType` "S3Prefix", the
 * URI produced by `formInputData`, and the given distribution), the content
 * type and the channel name. `InputMode` is added only when `inputMode` is
 * set to something other than the string "None".
 *
 * @param input - Description of a single training input.
 * @returns The channel object, with or without an `InputMode` key.
 */
const genTrainingInput = (input: TrainingInputStruct) => {
  // Identical for both result shapes, so it is built once up front.
  const dataSource = {
    S3DataSource: {
      S3DataType: "S3Prefix",
      S3Uri: formInputData(input.s3Data, input.s3DataUseParam),
      S3DataDistributionType: input.distribution,
    },
  };

  // No usable input mode: omit the key entirely.
  if (!input.inputMode || input.inputMode === "None") {
    return {
      DataSource: dataSource,
      ContentType: input.contentType,
      ChannelName: input.inputName,
    };
  }

  return {
    DataSource: dataSource,
    InputMode: input.inputMode,
    ContentType: input.contentType,
    ChannelName: input.inputName,
  };
};
