package ocidir

import (
	"errors"
	"testing"

	"github.com/opencontainers/go-digest"

	"github.com/regclient/regclient/scheme"
	"github.com/regclient/regclient/types/descriptor"
	"github.com/regclient/regclient/types/errs"
	"github.com/regclient/regclient/types/mediatype"
	v1 "github.com/regclient/regclient/types/oci/v1"
	"github.com/regclient/regclient/types/ref"
)

// Compile-time checks that *OCIDir satisfies the scheme interfaces; if any
// method set drifts, this test file stops compiling.
var (
	_ scheme.API       = (*OCIDir)(nil) // core registry operations
	_ scheme.Closer    = (*OCIDir)(nil) // Close support
	_ scheme.GCLocker  = (*OCIDir)(nil) // GC locking support
	_ scheme.Throttler = (*OCIDir)(nil) // throttling support
)

// TestIndex writes each case's index into a temporary OCI layout and reads it
// back. It then checks the indexGet lookup result and the number of manifests
// left in the index after indexSet.
// Every case shares one directory and one OCIDir. Each case rewrites the index
// before using it, and the subtests run one after another in the listed order.
func TestIndex(t *testing.T) {
	t.Parallel()
	dir := t.TempDir()
	oDir := New()
	digNoTag := digest.FromString("test digest 1")
	digTagAB := digest.FromString("test digest 2")
	digTagC := digest.FromString("test digest 3")
	repo, err := ref.New("ocidir://" + dir + "/testrepo")
	if err != nil {
		t.Fatalf("failed to generate ref: %v", err)
	}
	refA := repo.SetTag("tag-a")
	refB := repo.SetTag("tag-b")
	refC := repo.SetTag("tag-c")
	refDig := repo.SetDigest(digNoTag.String())

	// manifest returns a Docker v2 manifest descriptor for dig. A non-empty
	// name is stored under the aOCIRefName annotation; an empty name leaves
	// Annotations nil.
	manifest := func(dig digest.Digest, name string) descriptor.Descriptor {
		desc := descriptor.Descriptor{
			MediaType: mediatype.Docker2Manifest,
			Size:      1234,
			Digest:    dig,
		}
		if name != "" {
			desc.Annotations = map[string]string{aOCIRefName: name}
		}
		return desc
	}
	// ociIndex wraps the given manifests in an OCI image index.
	ociIndex := func(manifests ...descriptor.Descriptor) v1.Index {
		return v1.Index{
			Versioned: v1.IndexSchemaVersion,
			MediaType: mediatype.OCI1ManifestList,
			Manifests: manifests,
		}
	}
	descNoTag := manifest(digNoTag, "")
	descA := manifest(digTagAB, "tag-a")
	descB := manifest(digTagAB, "tag-b")
	// descC is annotated with the full reference name rather than a bare tag.
	descC := manifest(digTagC, refC.CommonName())

	cases := []struct {
		name         string
		index        v1.Index
		get          ref.Ref
		expectGet    descriptor.Descriptor
		expectGetErr error
		set          ref.Ref
		setDesc      descriptor.Descriptor
		expectLen    int
	}{
		{
			name:         "empty",
			get:          refA,
			expectGetErr: errs.ErrNotFound,
		},
		{
			name:      "no tag",
			index:     ociIndex(descNoTag),
			get:       refDig,
			expectGet: descNoTag,
			set:       refA,
			setDesc:   descA,
			expectLen: 2,
		},
		{
			name:      "tag a",
			index:     ociIndex(descNoTag, descA),
			get:       refDig,
			expectGet: descNoTag,
			set:       refC,
			setDesc:   descNoTag,
			expectLen: 2,
		},
		{
			name:      "tag b",
			index:     ociIndex(descNoTag, descB),
			get:       refB,
			expectGet: descB,
			set:       refB,
			setDesc:   descNoTag,
			expectLen: 1,
		},
		{
			name:      "tag c",
			index:     ociIndex(descA, descC),
			get:       refC,
			expectGet: descC,
			set:       refA,
			setDesc:   descNoTag,
			expectLen: 2,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if err := oDir.writeIndex(repo, tc.index, false); err != nil {
				t.Fatalf("failed to write index: %v", err)
			}
			index, err := oDir.readIndex(repo, false)
			if err != nil {
				t.Fatalf("failed to read index: %v", err)
			}

			if !tc.get.IsZero() {
				got, err := indexGet(index, tc.get)
				switch {
				case tc.expectGetErr == nil && err != nil:
					t.Errorf("indexGet failed: %v", err)
				case tc.expectGetErr == nil:
					if !got.Equal(tc.expectGet) {
						t.Errorf("indexGet descriptor, expected %v, received %v", tc.expectGet, got)
					}
				case err == nil:
					t.Errorf("indexGet did not fail")
				case !errors.Is(err, tc.expectGetErr) && err.Error() != tc.expectGetErr.Error():
					// Accept either a wrapped sentinel or an identical message.
					t.Errorf("unexpected error from indexGet, expected %v, received %v", tc.expectGetErr, err)
				}
			}

			if !tc.set.IsZero() {
				if err := indexSet(&index, tc.set, tc.setDesc); err != nil {
					t.Errorf("indexSet failed: %v", err)
				}
			}

			if got := len(index.Manifests); got != tc.expectLen {
				t.Errorf("unexpected length, expected %d, found %d, index: %v", tc.expectLen, got, index)
			}
		})
	}
}
