/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"

namespace xla {
namespace gpu {
namespace {

using se::DeviceMemory;
using se::DeviceMemoryBase;
using se::Stream;
using se::dnn::AlgorithmConfig;
using se::dnn::BatchDescriptor;
using se::dnn::ConvolutionDescriptor;
using se::dnn::DataLayout;
using se::dnn::DimIndex;
using se::dnn::FilterDescriptor;
using se::dnn::FilterLayout;
using se::dnn::ProfileResult;

// An se::ScratchAllocator backed by one pre-existing XLA scratch buffer.  The
// first successful AllocateBytes() call hands out the whole buffer (not just
// the requested number of bytes); every later call is rejected.
class ScratchBufAllocator : public se::ScratchAllocator {
 public:
  explicit ScratchBufAllocator(se::DeviceMemoryBase scratch)
      : buffer_(scratch) {}

  ~ScratchBufAllocator() override = default;

  // The limit is exactly the size of the wrapped buffer.
  int64 GetMemoryLimitInBytes() override { return buffer_.size(); }

  // Returns the wrapped buffer viewed as bytes.  Fails with an internal error
  // if the buffer was already handed out, or if `byte_size` exceeds the size of
  // the wrapped buffer.  A rejected oversize request does not consume the
  // buffer.
  se::port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
      int64 byte_size) override {
    if (handed_out_) {
      return se::port::InternalError(
          "Can't allocate twice from a ScratchBufAllocator.");
    }

    const auto capacity = buffer_.size();
    if (byte_size > capacity) {
      return se::port::InternalError(
          absl::StrCat("Can't allocate ", byte_size,
                       " bytes from a ScratchBufAllocator of size ", capacity));
    }

    handed_out_ = true;
    return DeviceMemory<uint8>(buffer_);
  }

 private:
  se::DeviceMemoryBase buffer_;  // The single allocation being wrapped.
  bool handed_out_ = false;      // True once buffer_ has been returned.
};

// Enqueues a plain (unfused) forward convolution on `stream` using the
// descriptors stored in `params.config`.
//
// Returns an internal error without launching anything if the configuration
// asks for a result scale other than 1; otherwise returns whatever
// Stream::ConvolveWithAlgorithm() reports.  Profiling output, if any, goes to
// `options.profile_result`.
template <typename ElementType, typename OutputType>
Status RunGpuConvForward(GpuConvParams params,
                         se::ScratchAllocator* scratch_allocator,
                         se::Stream* stream, RunConvOptions options,
                         DeviceMemory<ElementType> input_buf,
                         DeviceMemory<ElementType> filter_buf,
                         DeviceMemory<OutputType> output_buf,
                         AlgorithmConfig algorithm) {
  const auto& config = params.config;

  const auto result_scale = config.conv_result_scale;
  if (result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        result_scale);
  }

  return stream->ConvolveWithAlgorithm(
      config.input_descriptor, input_buf, config.filter_descriptor, filter_buf,
      config.conv_desc, config.output_descriptor, &output_buf,
      scratch_allocator, algorithm, options.profile_result);
}

// Enqueues a fused forward convolution (convolution + side input + bias +
// activation) on `stream`.
//
// The bias descriptor is a 1x1x1 batch whose feature-map count and layout are
// taken from the output descriptor.  The bias and optional side-input buffers
// come from `params.fusion`; the activation mode and side-input scale come
// from `params.config.fusion`.  Both optionals are dereferenced
// unconditionally, so the caller must have populated them.
//
// Returns an internal error if no side-input buffer was supplied while the
// side-input scale is nonzero; otherwise returns whatever
// Stream::FusedConvolveWithAlgorithm() reports.
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvForwardActivation(GpuConvParams params,
                                   se::ScratchAllocator* scratch_allocator,
                                   se::Stream* stream, RunConvOptions options,
                                   DeviceMemory<ElementType> input_buf,
                                   DeviceMemory<ElementType> filter_buf,
                                   DeviceMemory<OutputType> output_buf,
                                   AlgorithmConfig algorithm) {
  const auto& config = params.config;

  BatchDescriptor bias_desc;
  bias_desc.set_count(1);
  bias_desc.set_height(1);
  bias_desc.set_width(1);
  bias_desc.set_feature_map_count(
      config.output_descriptor.feature_map_count());
  bias_desc.set_layout(config.output_descriptor.layout());

  se::DeviceMemory<OutputType> side_input(params.fusion->side_input_buf);
  const bool missing_side_input = side_input.is_null();
  if (missing_side_input && config.fusion->side_input_scale != 0) {
    return InternalError(
        "Side input scale is not 0, yet no side input buffer is "
        "provided");
  }
  if (missing_side_input) {
    // With a side-input scale of 0 the side-input values are irrelevant, but
    // cudnn refuses a null side-input buffer.  cudnn promises not to read the
    // side input when its scale is 0, so the output buffer -- which is already
    // at hand and has the right size -- serves as a stand-in.
    side_input = output_buf;
  }

  return stream->FusedConvolveWithAlgorithm(
      config.input_descriptor, input_buf, config.conv_result_scale,
      config.filter_descriptor, filter_buf, config.conv_desc, side_input,
      config.fusion->side_input_scale, bias_desc,
      DeviceMemory<BiasType>(params.fusion->bias_buf), config.fusion->mode,
      config.output_descriptor, &output_buf, scratch_allocator, algorithm,
      options.profile_result);
}

// StreamExecutor exposes its convolution entry points as per-type overloads
// that are added on demand, so not every (kind, element type) combination
// exists.  enable_if is used to make sure we never instantiate a call to an
// overload that does not exist.
// TODO(timshen): Ideally, to avoid such complication in the runner, we can turn
// StreamExecutor overloadings to template functions, and for unsupported data
// types return runtime errors.
//
// This overload is selected for non-integral element types (double, float and
// half).  Every convolution kind is dispatched here.  The backward kinds are
// rejected with an internal error when the result scale is not 1; the plain
// forward kind performs the same check inside RunGpuConvForward().
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<
              !std::is_integral<ElementType>::value>::type* = nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  const auto& config = params.config;
  const auto kind = config.kind;

  // Both backward convolutions share the same "no scaling" restriction, so
  // check it once up front.
  const bool is_backward = kind == CudnnConvKind::kBackwardInput ||
                           kind == CudnnConvKind::kBackwardFilter;
  if (is_backward && config.conv_result_scale != 1) {
    return InternalError(
        "StreamExecutor doesn't support scaled convolution: %lf.",
        config.conv_result_scale);
  }

  switch (kind) {
    case CudnnConvKind::kForward:
      return RunGpuConvForward(params, scratch_allocator, stream, options,
                               input_buf, filter_buf, output_buf, algorithm);

    case CudnnConvKind::kForwardActivation:
      return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
          params, scratch_allocator, stream, options, input_buf, filter_buf,
          output_buf, algorithm);

    case CudnnConvKind::kBackwardInput:
      // Here `input_buf` is the destination: it receives the input gradient.
      return stream->ConvolveBackwardDataWithAlgorithm(
          config.filter_descriptor, filter_buf, config.output_descriptor,
          output_buf, config.conv_desc, config.input_descriptor, &input_buf,
          scratch_allocator, algorithm, options.profile_result);

    case CudnnConvKind::kBackwardFilter:
      // Here `filter_buf` is the destination: it receives the filter gradient.
      return stream->ConvolveBackwardFilterWithAlgorithm(
          config.input_descriptor, input_buf, config.output_descriptor,
          output_buf, config.conv_desc, config.filter_descriptor, &filter_buf,
          scratch_allocator, algorithm, options.profile_result);
  }

  // Only reached if `kind` holds a value that matches none of the cases above.
  return Status::OK();
}

// Overload selected for integral element types.  Only the two forward kinds
// (kForward and kForwardActivation) are dispatched; any other kind yields an
// internal error.
template <typename ElementType, typename BiasType, typename OutputType,
          typename std::enable_if<std::is_integral<ElementType>::value>::type* =
              nullptr>
Status RunGpuConvInternalImpl(GpuConvParams params,
                              se::ScratchAllocator* scratch_allocator,
                              se::Stream* stream, RunConvOptions options,
                              DeviceMemory<ElementType> input_buf,
                              DeviceMemory<ElementType> filter_buf,
                              DeviceMemory<OutputType> output_buf,
                              AlgorithmConfig algorithm) {
  const auto kind = params.config.kind;

  if (kind == CudnnConvKind::kForward) {
    return RunGpuConvForward(params, scratch_allocator, stream, options,
                             input_buf, filter_buf, output_buf, algorithm);
  }

  if (kind == CudnnConvKind::kForwardActivation) {
    return RunGpuConvForwardActivation<ElementType, BiasType, OutputType>(
        params, scratch_allocator, stream, options, input_buf, filter_buf,
        output_buf, algorithm);
  }

  return InternalError(
      "Only convolution kinds kForward and kForwardActivation are "
      "supported for integer types");
}

// Runs one convolution for a fixed (element, bias, output) type triple.
//
// Wraps the untyped buffers in `params` as typed device memory, picks the
// algorithm (the configured one, unless `options.algo_override` is set, in
// which case the override is used and `options.scratch_size_override`, if
// present, is applied on top of it), and dispatches to
// RunGpuConvInternalImpl().
//
// Returns the dispatch error if there was one; otherwise an internal error if
// the stream is no longer ok() after the launch; otherwise OK.
template <typename ElementType, typename BiasType, typename OutputType>
Status RunGpuConvImpl(const GpuConvParams& params,
                      se::ScratchAllocator* scratch_allocator,
                      se::Stream* stream, RunConvOptions options) {
  se::DeviceMemory<ElementType> input_buf(params.input_buf);
  se::DeviceMemory<ElementType> filter_buf(params.filter_buf);
  se::DeviceMemory<OutputType> output_buf(params.output_buf);

  AlgorithmConfig algorithm = params.config.algorithm;
  if (options.algo_override.has_value()) {
    algorithm = AlgorithmConfig(*options.algo_override);
    // The scratch-size override is only honored together with an algorithm
    // override.
    if (options.scratch_size_override.has_value()) {
      algorithm.set_scratch_size(*options.scratch_size_override);
    }
  }

  Status run_status = RunGpuConvInternalImpl<ElementType, BiasType, OutputType>(
      params, scratch_allocator, stream, options, input_buf, filter_buf,
      output_buf, algorithm);
  if (!run_status.ok()) {
    return run_status;
  }

  if (stream->ok()) {
    return Status::OK();
  }

  // The launch left the stream in an error state; report which algorithm(s)
  // were in play.
  const auto no_scratch_id =
      algorithm.algorithm_no_scratch().has_value()
          ? absl::StrCat(algorithm.algorithm_no_scratch()->algo_id())
          : "none";
  return InternalError(
      "Unable to launch convolution with type %s and algorithm (%d, %s)",
      CudnnConvKindToString(params.config.kind),
      algorithm.algorithm()->algo_id(), no_scratch_id);
}

}  // anonymous namespace

// Translates the cuDNN convolution custom-call `cudnn_call` into a
// GpuConvConfig: element types, convolution kind, algorithm, result scale,
// logical input/filter/output shapes, optional fusion settings, and the
// StreamExecutor batch/filter/convolution descriptors.
//
// Which of operand 0, operand 1 and tuple element 0 of the result plays the
// role of "input", "filter" and "output" depends on the convolution kind (see
// the switch below).
//
// Returns an error if the backend config or convolution kind cannot be read,
// if the kind is unknown, if a fused convolution carries an invalid activation
// mode, or if the XLA layouts cannot be mapped to StreamExecutor layouts.
// CHECK-fails if the window has more than 3 dimensions, if the window and
// dimension numbers disagree on the number of spatial dimensions, or if any
// window dimension has mixed reversal, asymmetric padding, or base dilation
// other than 1.
StatusOr<GpuConvConfig> GetGpuConvConfig(
    const HloCustomCallInstruction* cudnn_call) {
  GpuConvConfig config;

  config.input_type = cudnn_call->operand(0)->shape().element_type();
  config.output_type = cudnn_call->shape().tuple_shapes(0).element_type();

  TF_ASSIGN_OR_RETURN(CudnnConvBackendConfig backend_config,
                      cudnn_call->backend_config<CudnnConvBackendConfig>());
  TF_ASSIGN_OR_RETURN(config.kind, GetCudnnConvKind(cudnn_call));

  // The second AlgorithmConfig argument is the scratch size.  It is read from
  // dimension 0 of tuple element 1 of the instruction's result shape, which is
  // where GpuConvAlgorithmPicker::RunOnInstruction() records it.
  config.algorithm = se::dnn::AlgorithmConfig(
      se::dnn::AlgorithmDesc(backend_config.algorithm(),
                             backend_config.tensor_ops_enabled()),
      cudnn_call->shape().tuple_shapes(1).dimensions(0));
  config.conv_result_scale = backend_config.conv_result_scale();

  // Bind by reference: these shapes are only read here and are copied exactly
  // once, into `config`, by the switch below.
  const Shape& operand0_shape = cudnn_call->operand(0)->shape();
  const Shape& operand1_shape = cudnn_call->operand(1)->shape();
  const Shape& result_shape = cudnn_call->shape().tuple_shapes(0);

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      config.input_shape = operand0_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = result_shape;
      break;
    case CudnnConvKind::kBackwardInput:
      config.input_shape = result_shape;
      config.filter_shape = operand1_shape;
      config.output_shape = operand0_shape;
      break;
    case CudnnConvKind::kBackwardFilter:
      config.input_shape = operand0_shape;
      config.filter_shape = result_shape;
      config.output_shape = operand1_shape;
      break;
    default:
      return InternalError("Unknown convolution kind");
  }

  if (config.kind == CudnnConvKind::kForwardActivation) {
    config.fusion.emplace();
    GpuConvConfig::FusionConfig& fusion = *config.fusion;
    if (!se::dnn::ActivationMode_IsValid(backend_config.activation_mode())) {
      return InternalError("Bad activation mode: %s",
                           backend_config.ShortDebugString());
    }
    fusion.mode =
        static_cast<se::dnn::ActivationMode>(backend_config.activation_mode());
    fusion.side_input_scale = backend_config.side_input_scale();
  }

  const Window& window = cudnn_call->window();
  const ConvolutionDimensionNumbers& dnums =
      cudnn_call->convolution_dimension_numbers();

  VLOG(3) << "Convolution Algorithm: "
          << config.algorithm.algorithm()->algo_id();
  VLOG(3) << "tensor_ops_enabled: "
          << config.algorithm.algorithm()->tensor_ops_enabled();
  VLOG(3) << "Convolution kind: " << CudnnConvKindToString(config.kind);
  VLOG(3) << "input shape: "
          << ShapeUtil::HumanStringWithLayout(config.input_shape);
  VLOG(3) << "filter shape: "
          << ShapeUtil::HumanStringWithLayout(config.filter_shape);
  VLOG(3) << "Output shape: "
          << ShapeUtil::HumanStringWithLayout(config.output_shape);
  VLOG(3) << "Window: { " << window.ShortDebugString() << " }";
  VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }";

  const int num_dimensions = window.dimensions_size();
  CHECK_LE(num_dimensions, 3) << cudnn_call->ToString();

  // cuDNN does not support 1D convolutions. We therefore express 1D
  // convolutions as 2D convolutions where the first spatial dimension is 1.
  // This matches the behavior of TF (see definition of conv1d in
  // tensorflow/python/ops/nn_ops.py).
  const int effective_num_dimensions = std::max(2, num_dimensions);

  // XLA spatial dimension `dim` maps to the StreamExecutor DimIndex counted
  // from the other end, i.e. the spatial dimensions are reversed.
  const auto reversed_index = [effective_num_dimensions](int dim) {
    return static_cast<DimIndex>(effective_num_dimensions - dim - 1);
  };

  // If one dimension is reversed, we need to have all dimensions reversed (so
  // we're doing convolution not cross correlation).
  const bool dims_reversed =
      window.dimensions_size() > 0 && window.dimensions()[0].window_reversal();

  CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size())
      << cudnn_call->ToString();
  CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size())
      << cudnn_call->ToString();
  CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size())
      << cudnn_call->ToString();
  for (const WindowDimension& dim : window.dimensions()) {
    CHECK_EQ(dims_reversed, dim.window_reversal()) << cudnn_call->ToString();
    CHECK_EQ(dim.padding_low(), dim.padding_high()) << cudnn_call->ToString();
    CHECK_EQ(dim.base_dilation(), 1)
        << "cudnn does not support base dilation; it "
           "must be made explicit with a kPad: "
        << cudnn_call->ToString();
  }

  // cuDNN's convolution APIs support the BDYX layout for activations/output and
  // the OIYX layout for weights.
  DataLayout input_dl;
  FilterLayout filter_dl;
  DataLayout output_dl;

  const Shape& input_shape = config.input_shape;
  const Shape& filter_shape = config.filter_shape;
  const Shape& output_shape = config.output_shape;

  TF_ASSIGN_OR_RETURN(std::tie(input_dl, filter_dl, output_dl),
                      XlaConvLayoutsToStreamExecutorLayouts(
                          dnums, input_shape.layout(), filter_shape.layout(),
                          output_shape.layout()));

  BatchDescriptor& input_descriptor = config.input_descriptor;
  input_descriptor = BatchDescriptor(effective_num_dimensions);
  input_descriptor.set_layout(input_dl)
      .set_feature_map_count(
          input_shape.dimensions(dnums.input_feature_dimension()))
      .set_count(input_shape.dimensions(dnums.input_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    input_descriptor.set_spatial_dim(
        reversed_index(dim),
        input_shape.dimensions(dnums.input_spatial_dimensions(dim)));
  }

  FilterDescriptor& filter_descriptor = config.filter_descriptor;
  filter_descriptor = FilterDescriptor(effective_num_dimensions);
  filter_descriptor.set_layout(filter_dl)
      .set_input_feature_map_count(
          filter_shape.dimensions(dnums.kernel_input_feature_dimension()))
      .set_output_feature_map_count(
          filter_shape.dimensions(dnums.kernel_output_feature_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    filter_descriptor.set_spatial_dim(
        reversed_index(dim),
        filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim)));
  }

  config.conv_desc = ConvolutionDescriptor(effective_num_dimensions);
  config.conv_desc.set_group_count(cudnn_call->feature_group_count());
  config.conv_desc.set_convolution_not_crosscorr(dims_reversed);
  for (int dim = 0; dim < num_dimensions; ++dim) {
    const WindowDimension& window_dim = window.dimensions(dim);
    // The CHECK above guarantees padding_low == padding_high, so padding_low
    // describes the (symmetric) padding.
    config.conv_desc.set_zero_padding(reversed_index(dim),
                                      window_dim.padding_low())
        .set_filter_stride(reversed_index(dim), window_dim.stride())
        .set_dilation_rate(reversed_index(dim), window_dim.window_dilation());
  }

  BatchDescriptor& output_descriptor = config.output_descriptor;
  output_descriptor = BatchDescriptor(effective_num_dimensions);
  output_descriptor.set_layout(output_dl)
      .set_feature_map_count(
          output_shape.dimensions(dnums.output_feature_dimension()))
      .set_count(output_shape.dimensions(dnums.output_batch_dimension()));
  for (int dim = 0; dim < num_dimensions; ++dim) {
    output_descriptor.set_spatial_dim(
        reversed_index(dim),
        output_shape.dimensions(dnums.output_spatial_dimensions(dim)));
  }

  // Add a singleton dimension in the 1D convolution case.
  for (int dim = 0; dim < effective_num_dimensions - num_dimensions; dim++) {
    const DimIndex padded_index = static_cast<DimIndex>(dim);
    input_descriptor.set_spatial_dim(padded_index, 1);
    output_descriptor.set_spatial_dim(padded_index, 1);
    filter_descriptor.set_spatial_dim(padded_index, 1);
    config.conv_desc.set_zero_padding(padded_index, 0)
        .set_filter_stride(padded_index, 1);
  }

  return config;
}

// Pairs `config` with the device buffers for one convolution launch.
//
// Which of `operand_buffers[0]`, `operand_buffers[1]` and `result_buffer`
// becomes the input, filter and output buffer depends on `config.kind` (see the
// switch below).  For kForwardActivation, `operand_buffers[2]` is the bias and
// `operand_buffers[3]`, when present, is the side input.
//
// Returns an internal error if `operand_buffers` is too short for the
// convolution kind (at least 2 buffers are always required, at least 3 for
// kForwardActivation), or if the kind is not one of the known kinds.
StatusOr<GpuConvParams> GetGpuConvParams(
    const GpuConvConfig& config,
    absl::Span<se::DeviceMemoryBase> operand_buffers,
    se::DeviceMemoryBase result_buffer) {
  const bool is_fused = config.kind == CudnnConvKind::kForwardActivation;

  // Validate the span before indexing into it: absl::Span::operator[] does not
  // return an error on out-of-range access.
  const size_t min_operands = is_fused ? 3 : 2;
  if (operand_buffers.size() < min_operands) {
    return InternalError(
        "Convolution of kind %s needs at least %d operand buffers, but only "
        "%d were provided",
        CudnnConvKindToString(config.kind), min_operands,
        operand_buffers.size());
  }

  GpuConvParams params;
  params.config = config;

  switch (config.kind) {
    case CudnnConvKind::kForward:
    case CudnnConvKind::kForwardActivation:
      params.input_buf = operand_buffers[0];
      params.filter_buf = operand_buffers[1];
      params.output_buf = result_buffer;
      break;
    case CudnnConvKind::kBackwardInput:
      params.input_buf = result_buffer;
      params.filter_buf = operand_buffers[1];
      params.output_buf = operand_buffers[0];
      break;
    case CudnnConvKind::kBackwardFilter:
      params.input_buf = operand_buffers[0];
      params.filter_buf = result_buffer;
      params.output_buf = operand_buffers[1];
      break;
    default:
      // Mirrors GetGpuConvConfig(): never hand back params whose buffers were
      // silently left unset.
      return InternalError("Unknown convolution kind");
  }

  if (is_fused) {
    params.fusion.emplace();
    GpuConvParams::FusionParams& fusion = *params.fusion;
    fusion.bias_buf = operand_buffers[2];
    // The side input is optional; when absent side_input_buf keeps its default
    // value.
    if (operand_buffers.size() >= 4) {
      fusion.side_input_buf = operand_buffers[3];
    }
  }

  return params;
}

// Convenience overload for callers that already hold a raw scratch buffer:
// wraps `scratch_buf` in a single-use ScratchBufAllocator and forwards to the
// allocator-based RunGpuConv() overload.
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::DeviceMemoryBase scratch_buf, se::Stream* stream,
                  RunConvOptions options) {
  ScratchBufAllocator single_use_allocator(scratch_buf);
  se::ScratchAllocator* allocator = &single_use_allocator;
  return RunGpuConv(config, operand_buffers, result_buffer, allocator, stream,
                    options);
}

// Runs the convolution described by `config` on `stream`, obtaining scratch
// memory from `scratch_allocator`.
//
// The (element, bias, output) template types are chosen from
// `config.input_type`, and for S8 inputs additionally from
// `config.output_type`:
//   F16 -> <half, half, half>      F32 -> <float, float, float>
//   F64 -> <double, double, double>
//   S8 with F32 output -> <int8, float, float>
//   S8 with S8 output  -> <int8, float, int8>
// Any other combination returns Unimplemented.  Errors from
// GetGpuConvParams() are propagated.
Status RunGpuConv(const gpu::GpuConvConfig& config,
                  absl::Span<se::DeviceMemoryBase> operand_buffers,
                  se::DeviceMemoryBase result_buffer,
                  se::ScratchAllocator* scratch_allocator, se::Stream* stream,
                  RunConvOptions options) {
  TF_ASSIGN_OR_RETURN(GpuConvParams params,
                      GetGpuConvParams(config, operand_buffers, result_buffer));

  switch (config.input_type) {
    case F16:
      return RunGpuConvImpl<Eigen::half, Eigen::half, Eigen::half>(
          params, scratch_allocator, stream, options);

    case F32:
      return RunGpuConvImpl<float, float, float>(params, scratch_allocator,
                                                 stream, options);

    case F64:
      return RunGpuConvImpl<double, double, double>(params, scratch_allocator,
                                                    stream, options);

    case S8:
      // Integer inputs may produce either float or int8 outputs; the bias is
      // float in both cases.
      if (config.output_type == F32) {
        return RunGpuConvImpl<int8, float, float>(params, scratch_allocator,
                                                  stream, options);
      }
      if (config.output_type == S8) {
        return RunGpuConvImpl<int8, float, int8>(params, scratch_allocator,
                                                 stream, options);
      }
      return Unimplemented("Unimplemented convolution");

    default:
      return Unimplemented("Unimplemented convolution");
  }
}

}  // namespace gpu
}  // namespace xla
