﻿// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Diagnostics;
using Microsoft.CodeAnalysis.Text;
using Xunit;

namespace SourceGenerators.Tests
{
    internal static class RoslynTestUtils
    {
        /// <summary>
        /// Builds an <see cref="AdhocWorkspace"/> that already contains a single, newly created solution.
        /// </summary>
        /// <returns>The workspace with the new solution added to it.</returns>
        public static AdhocWorkspace CreateTestWorkspace()
        {
            // The solution gets a fresh id and version stamp on every call.
            SolutionInfo emptySolution = SolutionInfo.Create(SolutionId.CreateNewId(), VersionStamp.Create());

            var testWorkspace = new AdhocWorkspace();
            testWorkspace.AddSolution(emptySolution);
            return testWorkspace;
        }

        /// <summary>
        /// Derives a C# project named "Test" (assembly "test.dll") from the workspace's current solution.
        /// </summary>
        /// <param name="workspace">Workspace whose current solution the project is added to.</param>
        /// <param name="references">Assemblies to reference through their file location; may be <see langword="null"/>.</param>
        /// <param name="includeBaseReferences">
        /// When <see langword="true"/>, references the assembly that defines <see cref="object"/> together with
        /// netstandard.dll and System.Runtime.dll taken from the same directory.
        /// </param>
        /// <param name="langVersion">C# language version used for the project's parse options.</param>
        /// <returns>
        /// The new project, configured as a dynamically linked library with the nullable context enabled.
        /// This method does not apply the change to the workspace.
        /// </returns>
        public static Project CreateTestProject(
            AdhocWorkspace workspace,
            IEnumerable<Assembly>? references,
            bool includeBaseReferences = true,
            LanguageVersion langVersion = LanguageVersion.Preview)
        {
            string coreLibPath = Assembly.GetAssembly(typeof(object))!.Location;
            string frameworkDir = Path.GetDirectoryName(coreLibPath)!;

            var metadataRefs = new List<MetadataReference>();
            if (includeBaseReferences)
            {
                string[] basePaths =
                {
                    coreLibPath,
                    Path.Combine(frameworkDir, "netstandard.dll"),
                    Path.Combine(frameworkDir, "System.Runtime.dll"),
                };

                foreach (string basePath in basePaths)
                {
                    metadataRefs.Add(MetadataReference.CreateFromFile(basePath));
                }
            }

            if (references != null)
            {
                metadataRefs.AddRange(references.Select(asm => MetadataReference.CreateFromFile(asm.Location)));
            }

            CSharpCompilationOptions compilationOptions =
                new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)
                    .WithNullableContextOptions(NullableContextOptions.Enable);
            var parseOptions = new CSharpParseOptions(langVersion);

            Project testProject = workspace.CurrentSolution.AddProject("Test", "test.dll", "C#");
            return testProject
                .WithMetadataReferences(metadataRefs)
                .WithCompilationOptions(compilationOptions)
                .WithParseOptions(parseOptions);
        }

        /// <summary>
        /// Applies the project's solution to its workspace, then checks the project for unexpected diagnostics.
        /// </summary>
        /// <param name="proj">Project whose solution is applied and whose documents are checked.</param>
        /// <param name="ignorables">Diagnostic ids that are tolerated by the check.</param>
        /// <returns>The task produced by <see cref="AssertNoDiagnostic"/>.</returns>
        public static Task CommitChanges(this Project proj, params string[] ignorables)
        {
            Solution pendingSolution = proj.Solution;
            bool applied = pendingSolution.Workspace.TryApplyChanges(pendingSolution);
            Assert.True(applied);

            return proj.AssertNoDiagnostic(ignorables);
        }

        /// <summary>
        /// Asserts that every diagnostic reported by the semantic model of each document in the project
        /// has an id listed in <paramref name="ignorables"/>.
        /// </summary>
        /// <param name="proj">Project whose documents are inspected.</param>
        /// <param name="ignorables">Diagnostic ids that are allowed to be present.</param>
        public static async Task AssertNoDiagnostic(this Project proj, params string[] ignorables)
        {
            foreach (Document document in proj.Documents)
            {
                SemanticModel? model = await document.GetSemanticModelAsync(CancellationToken.None).ConfigureAwait(false);
                Assert.NotNull(model);

                foreach (Diagnostic diagnostic in model!.GetDiagnostics())
                {
                    // The diagnostic text is used as the failure message so the offending entry is visible.
                    Assert.True(ignorables.Contains(diagnostic.Id), diagnostic.ToString());
                }
            }
        }

        /// <summary>
        /// Adds one document per entry of <paramref name="sources"/> to the project.
        /// </summary>
        /// <param name="project">Project to extend.</param>
        /// <param name="sources">Source text of each document, in order.</param>
        /// <param name="sourceNames">
        /// Optional document names, matched to <paramref name="sources"/> by position. When omitted,
        /// documents are named "src-0.cs", "src-1.cs", and so on.
        /// </param>
        /// <returns>The project containing the added documents.</returns>
        private static Project WithDocuments(this Project project, IEnumerable<string> sources, IEnumerable<string>? sourceNames = null)
        {
            List<string>? explicitNames = sourceNames?.ToList();

            Project updated = project;
            int index = 0;
            foreach (string source in sources)
            {
                string documentName = explicitNames != null ? explicitNames[index] : $"src-{index}.cs";
                updated = updated.WithDocument(documentName, source);
                index++;
            }

            return updated;
        }

        /// <summary>
        /// Adds a document with the given name and text to the project.
        /// </summary>
        /// <param name="proj">Project to extend.</param>
        /// <param name="name">Name of the new document.</param>
        /// <param name="text">Content of the new document.</param>
        /// <returns>The project that owns the newly added document.</returns>
        public static Project WithDocument(this Project proj, string name, string text)
        {
            Document added = proj.AddDocument(name, text);
            return added.Project;
        }

        /// <summary>
        /// Returns the first document in the project whose name equals <paramref name="name"/>.
        /// </summary>
        /// <param name="proj">Project to search.</param>
        /// <param name="name">Document name to look for.</param>
        /// <returns>The matching document.</returns>
        /// <exception cref="FileNotFoundException">No document in the project has that name.</exception>
        public static Document FindDocument(this Project proj, string name)
        {
            Document? match = proj.Documents.FirstOrDefault(candidate => candidate.Name == name);
            return match ?? throw new FileNotFoundException(name);
        }

        /// <summary>
        /// Looks for /*N+*/ and /*-N*/ markers in a string and creates a TextSpan for the text between them.
        /// </summary>
        /// <param name="text">Text containing the markers.</param>
        /// <param name="spanNum">The number N identifying the marker pair.</param>
        /// <returns>
        /// A span that starts right after the opening marker and ends one character before the closing
        /// marker (the character immediately preceding the closing marker is not part of the span).
        /// </returns>
        /// <exception cref="ArgumentOutOfRangeException">
        /// The opening marker is missing, or no closing marker follows it.
        /// </exception>
        public static TextSpan MakeSpan(string text, int spanNum)
        {
            string openMarker = $"/*{spanNum}+*/";
            string closeMarker = $"/*-{spanNum}*/";

            int start = text.IndexOf(openMarker, StringComparison.Ordinal);
            if (start < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(spanNum));
            }

            // Skip the marker itself. Its length depends on how many characters spanNum takes,
            // so it must not be assumed to be 6 (that only holds for a single digit).
            start += openMarker.Length;

            // Only a closing marker located after the opening marker can delimit the span.
            int end = text.IndexOf(closeMarker, start, StringComparison.Ordinal);
            if (end < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(spanNum));
            }

            // Existing contract: the span stops one character short of the closing marker.
            end -= 1;

            return new TextSpan(start, end - start);
        }

        /// <summary>
        /// Runs a Roslyn generator over a set of source files.
        /// </summary>
        /// <param name="generator">Generator to run.</param>
        /// <param name="references">Assembly references to include in the test project; may be <see langword="null"/>.</param>
        /// <param name="sources">Source text of each document added to the test project.</param>
        /// <param name="includeBaseReferences">Whether to include references to the BCL assemblies.</param>
        /// <param name="langVersion">C# language version used to parse the sources.</param>
        /// <param name="cancellationToken">Token observed while building the compilation and running the generator.</param>
        /// <returns>The diagnostics and generated sources produced by the generator.</returns>
        public static async Task<(ImmutableArray<Diagnostic>, ImmutableArray<GeneratedSourceResult>)> RunGenerator(
#if ROSLYN4_0_OR_GREATER
            IIncrementalGenerator generator,
#else
            ISourceGenerator generator,
#endif
            IEnumerable<Assembly>? references,
            IEnumerable<string> sources,
            bool includeBaseReferences = true,
            LanguageVersion langVersion = LanguageVersion.Preview,
            CancellationToken cancellationToken = default)
        {
            using var workspace = CreateTestWorkspace();
            Project proj = CreateTestProject(workspace, references, includeBaseReferences, langVersion);
            proj = proj.WithDocuments(sources);
            Assert.True(proj.Solution.Workspace.TryApplyChanges(proj.Solution));

            // Flow the caller's token so that building the compilation can be cancelled too.
            Compilation? comp = await proj.GetCompilationAsync(cancellationToken).ConfigureAwait(false);
            Assert.NotNull(comp);

            return RunGenerator(comp!, generator, cancellationToken);
        }

        /// <summary>
        /// Runs a single Roslyn generator against an existing compilation.
        /// </summary>
        /// <param name="compilation">Compilation the generator runs against.</param>
        /// <param name="generator">Generator to run.</param>
        /// <param name="cancellationToken">Token passed to the generator driver.</param>
        /// <returns>The diagnostics and generated sources reported for the generator.</returns>
        public static (ImmutableArray<Diagnostic>, ImmutableArray<GeneratedSourceResult>) RunGenerator(
            Compilation compilation,
#if ROSLYN4_0_OR_GREATER
            IIncrementalGenerator generator,
#else
            ISourceGenerator generator,
#endif
            CancellationToken cancellationToken = default)
        {
            CSharpGeneratorDriver driver = CSharpGeneratorDriver.Create(new[] { generator });
            GeneratorDriverRunResult runResult = driver.RunGenerators(compilation, cancellationToken).GetRunResult();

            // Only one generator was registered, so its result is the first entry.
            GeneratorRunResult generatorResult = runResult.Results[0];
            return (generatorResult.Diagnostics, generatorResult.GeneratedSources);
        }

        /// <summary>
        /// Builds a test project from the given sources and collects the diagnostics produced when
        /// compiling it with the analyzer attached.
        /// </summary>
        /// <param name="analyzer">Analyzer to run.</param>
        /// <param name="references">Assembly references to include in the test project.</param>
        /// <param name="sources">Source text of each document added to the test project.</param>
        /// <returns>All diagnostics returned for the compilation with the analyzer attached.</returns>
        public static async Task<IList<Diagnostic>> RunAnalyzer(
            DiagnosticAnalyzer analyzer,
            IEnumerable<Assembly> references,
            IEnumerable<string> sources)
        {
            using var workspace = CreateTestWorkspace();

            Project proj = CreateTestProject(workspace, references).WithDocuments(sources);

            // Applies the solution and asserts that the sources have no diagnostics of their own.
            await proj.CommitChanges().ConfigureAwait(false);

            Compilation? compilation = await proj.GetCompilationAsync().ConfigureAwait(false);
            CompilationWithAnalyzers analyzed = compilation!.WithAnalyzers(ImmutableArray.Create(analyzer));
            return await analyzed.GetAllDiagnosticsAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Runs a Roslyn analyzer and repeatedly applies the first code action offered by the fixer
        /// until no diagnostics remain, no code action is offered, or applying the action changes nothing.
        /// </summary>
        /// <param name="analyzer">Analyzer whose diagnostics drive the fixer.</param>
        /// <param name="fixer">Code fix provider asked to register fixes for each diagnostic.</param>
        /// <param name="references">Assembly references to include in the test project.</param>
        /// <param name="sources">Source text of each document added to the test project.</param>
        /// <param name="sourceNames">
        /// Optional document names, matched to <paramref name="sources"/> by position. When omitted,
        /// documents are named "src-0.cs", "src-1.cs", and so on.
        /// </param>
        /// <param name="defaultNamespace">Optional default namespace set on the test project.</param>
        /// <param name="extraFile">Optional name of an additional document whose text is appended to the result.</param>
        /// <returns>
        /// The final text of each source document, in input order, followed by the text of
        /// <paramref name="extraFile"/> when given. Line endings are normalized by <c>ReplaceLineEndings</c>.
        /// </returns>
        public static async Task<IList<string>> RunAnalyzerAndFixer(
            DiagnosticAnalyzer analyzer,
            CodeFixProvider fixer,
            IEnumerable<Assembly> references,
            IEnumerable<string> sources,
            IEnumerable<string>? sourceNames = null,
            string? defaultNamespace = null,
            string? extraFile = null)
        {
            // Materialize the inputs once; they are needed both to create the documents and to read them back.
            List<string> sourceList = sources.ToList();
            List<string>? nameList = sourceNames?.ToList();

            using var workspace = CreateTestWorkspace();
            Project proj = CreateTestProject(workspace, references);

            proj = proj.WithDocuments(sourceList, nameList);

            if (defaultNamespace != null)
            {
                proj = proj.WithDefaultNamespace(defaultNamespace);
            }

            await proj.CommitChanges().ConfigureAwait(false);

            ImmutableArray<DiagnosticAnalyzer> analyzers = ImmutableArray.Create(analyzer);

            while (true)
            {
                Compilation? comp = await proj.GetCompilationAsync().ConfigureAwait(false);
                ImmutableArray<Diagnostic> diags = await comp!.WithAnalyzers(analyzers).GetAllDiagnosticsAsync().ConfigureAwait(false);
                if (diags.IsEmpty)
                {
                    // no more diagnostics reported
                    break;
                }

                var actions = new List<CodeAction>();
                foreach (Diagnostic d in diags)
                {
                    Document? doc = proj.GetDocument(d.Location.SourceTree);
                    if (doc is null)
                    {
                        // The diagnostic is not located in a document of this project,
                        // so there is nothing a code fix context could be created for.
                        continue;
                    }

                    CodeFixContext context = new CodeFixContext(doc, d, (action, _) => actions.Add(action), CancellationToken.None);
                    await fixer.RegisterCodeFixesAsync(context).ConfigureAwait(false);
                }

                if (actions.Count == 0)
                {
                    // nothing to fix
                    break;
                }

                ImmutableArray<CodeActionOperation> operations = await actions[0].GetOperationsAsync(CancellationToken.None).ConfigureAwait(false);
                Solution solution = operations.OfType<ApplyChangesOperation>().Single().ChangedSolution;
                Project? changedProj = solution.GetProject(proj.Id);
                if (changedProj == proj)
                {
                    // The action left the project untouched; running it again would yield the same
                    // diagnostics and the same action forever, so stop here.
                    break;
                }

                proj = await RecreateProjectDocumentsAsync(changedProj!).ConfigureAwait(false);
            }

            // Names of the documents to read back, in the order they are returned.
            var documentNames = new List<string>(sourceList.Count + 1);
            for (int i = 0; i < sourceList.Count; i++)
            {
                documentNames.Add(nameList != null ? nameList[i] : $"src-{i}.cs");
            }

            if (extraFile != null)
            {
                documentNames.Add(extraFile);
            }

            var results = new List<string>(documentNames.Count);
            foreach (string documentName in documentNames)
            {
                SourceText s = await proj.FindDocument(documentName).GetTextAsync().ConfigureAwait(false);
                results.Add(ReplaceLineEndings(s.ToString()));
            }

            return results;
        }

        /// <summary>
        /// Compares expected lines against the lines of a <see cref="SourceText"/> using ordinal comparison.
        /// </summary>
        /// <param name="expectedLines">Expected content, one entry per line.</param>
        /// <param name="sourceText">Text whose lines are checked.</param>
        /// <param name="message">
        /// Describes the line-count mismatch or the first differing line; empty when everything matches.
        /// </param>
        /// <returns><see langword="true"/> when the line counts and every line match; otherwise <see langword="false"/>.</returns>
        public static bool CompareLines(string[] expectedLines, SourceText sourceText, out string message)
        {
            TextLineCollection actualLines = sourceText.Lines;
            if (expectedLines.Length != actualLines.Count)
            {
                message = string.Format("Line numbers do not match. Expected: {0} lines, but generated {1}",
                    expectedLines.Length, actualLines.Count);
                return false;
            }

            for (int i = 0; i < actualLines.Count; i++)
            {
                TextLine actual = actualLines[i];
                string expected = expectedLines[i];
                if (!expected.Equals(actual.ToString(), StringComparison.Ordinal))
                {
                    // Report the line number 1-based.
                    message = string.Format("Line {0} does not match.{1}Expected Line:{1}{2}{1}Actual Line:{1}{3}",
                        actual.LineNumber + 1, Environment.NewLine, expected, actual);
                    return false;
                }
            }

            message = string.Empty;
            return true;
        }

        /// <summary>
        /// Replaces the text of every document in the project with a newly created copy of that text.
        /// </summary>
        /// <param name="project">Project whose documents are recreated.</param>
        /// <returns>The project that results after all documents have been replaced.</returns>
        private static async Task<Project> RecreateProjectDocumentsAsync(Project project)
        {
            Project current = project;

            // The ids come from the incoming project; each replacement yields a new project snapshot.
            foreach (DocumentId id in project.DocumentIds)
            {
                Document refreshed = await RecreateDocumentAsync(current.GetDocument(id)!).ConfigureAwait(false);
                current = refreshed.Project;
            }

            return current;
        }

        /// <summary>
        /// Returns the document with its text replaced by a new <see cref="SourceText"/> built from the
        /// current text's string content, encoding and checksum algorithm.
        /// </summary>
        /// <param name="document">Document to recreate.</param>
        /// <returns>The document carrying the rebuilt text.</returns>
        private static async Task<Document> RecreateDocumentAsync(Document document)
        {
            SourceText current = await document.GetTextAsync().ConfigureAwait(false);
            SourceText copy = SourceText.From(current.ToString(), current.Encoding, current.ChecksumAlgorithm);
            return document.WithText(copy);
        }

        /// <summary>
        /// Normalizes line endings in <paramref name="text"/> to "\n".
        /// </summary>
        /// <remarks>
        /// When NETCOREAPP is defined this uses <c>string.ReplaceLineEndings</c>; otherwise only
        /// "\r\n" sequences are replaced.
        /// </remarks>
        /// <param name="text">Text to normalize.</param>
        /// <returns>The text with line endings replaced.</returns>
        private static string ReplaceLineEndings(string text)
        {
#if NETCOREAPP
            return text.ReplaceLineEndings("\n");
#else
            return text.Replace("\r\n", "\n");
#endif
        }
    }
}
