import { TrainingCodeStruct } from "@/@types/project/mlPipeline/SageMaker/pipeline";
import {
  TrainingStep,
  TrainingInputStruct,
  TrainingOutputStruct,
} from "@/@types/project/mlPipeline/SageMaker/pipeline-training";
import {
  StepInfo,
  PipelineParameter,
} from "@/@types/project/mlPipeline/SageMaker/pipeline";
import {
  formNumber,
  generateUri,
  formEnvVars,
  formInputData,
  generateImageUri,
  formHyperparameters,
} from "./helper-functions";

// ----------------------------------------------------------------------

/**
 * Build the SageMaker pipeline definition entry for a training step.
 *
 * The result carries the step's `Name` and `Type`, plus an `Arguments` payload
 * shaped like a CreateTrainingJob request. That payload contains:
 * - the algorithm specification (file input mode, metrics time series enabled);
 * - the input channels and output location;
 * - a fixed 86400-second runtime cap;
 * - compute resources, hyperparameters, role and environment variables.
 *
 * Values that may come from pipeline parameters are resolved through the
 * matching `*UseParam` flags by the helpers in `./helper-functions`.
 * Step caching (`CacheConfig`) is not configured.
 *
 * @param step - Training step configuration (estimator, inputs, output, code).
 * @param roleArn - IAM role ARN used as the training job's `RoleArn`.
 * @param pipelineParams - Pipeline parameters, forwarded to `formHyperparameters`.
 * @returns The step definition, asserted to `StepInfo`.
 */
export default function generateTrainingStep(
  step: TrainingStep,
  roleArn: string,
  pipelineParams: PipelineParameter[]
) {
  const { estimator } = step;

  // Resolve every helper-derived value up front, in the same order in which
  // the definition below lists them.
  const trainingImage = generateImageUri(
    estimator.estimatorType,
    estimator.estimatorTypeUseParam
  );
  const inputDataConfig = step.trainingInputs.map(genTrainingInput);
  const outputDataConfig = genTrainingOutput(step.trainingOutput);
  const resourceConfig = {
    VolumeSizeInGB: formNumber(
      estimator.volumeSizeInGb,
      estimator.volumeSizeInGbUseParam
    ),
    InstanceCount: formNumber(
      estimator.instanceCount,
      estimator.instanceCountUseParam
    ),
    InstanceType: generateUri(
      estimator.instanceType,
      estimator.instanceTypeUseParam
    ),
  };
  // The script-location keys from formTrainingCode are spread last, so they
  // take precedence over any user hyperparameter with the same name.
  const hyperParameters = {
    ...formHyperparameters(estimator.hyperparameters, pipelineParams),
    ...formTrainingCode(step.code),
  };
  const environment = formEnvVars(estimator.environmentVars);

  return {
    Name: step.name,
    Type: step.type,
    Arguments: {
      AlgorithmSpecification: {
        TrainingInputMode: "File",
        TrainingImage: trainingImage,
        MetricDefinitions: step.metricDefinitions,
        EnableSageMakerMetricsTimeSeries: true,
      },
      InputDataConfig: inputDataConfig,
      OutputDataConfig: outputDataConfig,
      StoppingCondition: { MaxRuntimeInSeconds: 86400 },
      ResourceConfig: resourceConfig,
      HyperParameters: hyperParameters,
      RoleArn: roleArn,
      Environment: environment,
    },
  } as StepInfo;
}

/**
 * Build one `InputDataConfig` channel for a training job.
 *
 * The data source is always an `S3Prefix` S3 data source. Its URI is produced
 * by `formInputData` from `s3Data` and the `s3DataUseParam` flag.
 *
 * `InputMode` is emitted only when `inputMode` is non-empty and not the string
 * `"None"`; otherwise the key is omitted entirely. Per the SageMaker `Channel`
 * API, an omitted channel `InputMode` defers to the job-level
 * `TrainingInputMode`.
 *
 * @param input - Channel description (name, S3 data, distribution, content type, input mode).
 * @returns The channel object to place in `InputDataConfig`.
 */
const genTrainingInput = (input: TrainingInputStruct) => {
  const {
    inputName,
    s3Data,
    s3DataUseParam,
    distribution,
    contentType,
    inputMode,
  } = input;

  // An empty value or the literal "None" means no channel-level mode is set.
  const hasInputMode = Boolean(inputMode) && inputMode !== "None";

  return {
    DataSource: {
      S3DataSource: {
        S3DataType: "S3Prefix",
        S3Uri: formInputData(s3Data, s3DataUseParam),
        S3DataDistributionType: distribution,
      },
    },
    ContentType: contentType,
    // Conditional spread keeps InputMode between ContentType and ChannelName
    // when present, and leaves the key out entirely when absent.
    ...(hasInputMode ? { InputMode: inputMode } : {}),
    ChannelName: inputName,
  };
};

/**
 * Build the `OutputDataConfig` for a training job.
 *
 * Only `S3OutputPath` is set. Its value comes from `formInputData`, which
 * receives the configured path together with its `s3OutputPathUseParam` flag.
 *
 * @param output - Output location description for the training step.
 * @returns An object with a single `S3OutputPath` key.
 */
const genTrainingOutput = ({
  s3OutputPath,
  s3OutputPathUseParam,
}: TrainingOutputStruct) => ({
  S3OutputPath: formInputData(s3OutputPath, s3OutputPathUseParam),
});

/**
 * Build the hyperparameters that tell the training container which script to
 * run.
 *
 * Returns an empty object when no entry point is configured, so spreading the
 * result into `HyperParameters` adds nothing. Otherwise it emits:
 * - `sagemaker_program` (the entry point);
 * - `sagemaker_container_log_level` = "20" (presumably Python's logging.INFO);
 * - `sagemaker_submit_directory`, only when a non-blank source directory is set.
 *
 * The emptiness checks ignore surrounding whitespace, so the emitted values are
 * trimmed the same way. Otherwise a value like " train.py " would pass the
 * check and be sent verbatim.
 *
 * @param code - Entry point and optional source directory for the training script.
 * @returns Hyperparameter entries to merge into the training job definition.
 */
const formTrainingCode = (code: TrainingCodeStruct) => {
  const entryPoint = code.entryPoint ? code.entryPoint.trim() : "";
  if (entryPoint === "") return {};

  const sourceDir = code.sourceDir ? code.sourceDir.trim() : "";

  return {
    // Only point the container at a submit directory when one was provided.
    ...(sourceDir !== "" ? { sagemaker_submit_directory: sourceDir } : {}),
    sagemaker_program: entryPoint,
    sagemaker_container_log_level: "20",
  };
};
