Skip to content

Commit e3a4f49

Browse files
committed
feat: implement creator_id factor
1 parent ba52a78 commit e3a4f49

File tree

6 files changed

+75
-5
lines changed

6 files changed

+75
-5
lines changed

plugin/filter/filter.go

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
// MemoFilterCELAttributes are the CEL attributes for memo.
1010
var MemoFilterCELAttributes = []cel.EnvOption{
1111
cel.Variable("content", cel.StringType),
12+
cel.Variable("creator_id", cel.IntType),
1213
// As the built-in timestamp type is deprecated, we use string type for now.
1314
// e.g., "2021-01-01T00:00:00Z"
1415
cel.Variable("create_time", cel.StringType),

server/router/api/v1/memo_service.go

+15-2
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
120120
memoFind.OrderByTimeAsc = true
121121
}
122122
if request.Filter != "" {
123+
if err := s.validateFilter(ctx, request.Filter); err != nil {
124+
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
125+
}
123126
memoFind.Filter = &request.Filter
124127
}
125128

@@ -129,8 +132,18 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
129132
}
130133
if currentUser == nil {
131134
memoFind.VisibilityList = []store.Visibility{store.Public}
132-
} else if memoFind.CreatorID == nil || *memoFind.CreatorID != currentUser.ID {
133-
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
135+
} else {
136+
if memoFind.CreatorID == nil {
137+
internalFilter := fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "Protected"]`, currentUser.ID)
138+
if memoFind.Filter != nil {
139+
filter := fmt.Sprintf("(%s) && (%s)", *memoFind.Filter, internalFilter)
140+
memoFind.Filter = &filter
141+
} else {
142+
memoFind.Filter = &internalFilter
143+
}
144+
} else if *memoFind.CreatorID != currentUser.ID {
145+
memoFind.VisibilityList = []store.Visibility{store.Public, store.Protected}
146+
}
134147
}
135148

136149
workspaceMemoRelatedSetting, err := s.Store.GetWorkspaceMemoRelatedSetting(ctx)

store/db/mysql/memo_filter.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
5959
if err != nil {
6060
return err
6161
}
62-
if !slices.Contains([]string{"create_time", "update_time"}, identifier) {
62+
if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
6363
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
6464
}
6565
value, err := filter.GetConstValue(v.CallExpr.Args[1])
@@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
121121
return err
122122
}
123123
ctx.Args = append(ctx.Args, valueStr)
124+
} else if identifier == "creator_id" {
125+
if operator != "=" && operator != "!=" {
126+
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
127+
}
128+
valueInt, ok := value.(int64)
129+
if !ok {
130+
return errors.New("invalid int value")
131+
}
132+
133+
var factor string
134+
if identifier == "creator_id" {
135+
factor = "`memo`.`creator_id`"
136+
}
137+
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
138+
return err
139+
}
140+
ctx.Args = append(ctx.Args, valueInt)
124141
}
125142
case "@in":
126143
if len(v.CallExpr.Args) != 2 {

store/db/postgres/memo_filter.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
5959
if err != nil {
6060
return err
6161
}
62-
if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) {
62+
if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
6363
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
6464
}
6565
value, err := filter.GetConstValue(v.CallExpr.Args[1])
@@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
121121
return err
122122
}
123123
ctx.Args = append(ctx.Args, valueStr)
124+
} else if identifier == "creator_id" {
125+
if operator != "=" && operator != "!=" {
126+
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
127+
}
128+
valueInt, ok := value.(int64)
129+
if !ok {
130+
return errors.New("invalid int value")
131+
}
132+
133+
var factor string
134+
if identifier == "creator_id" {
135+
factor = "memo.creator_id"
136+
}
137+
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
138+
return err
139+
}
140+
ctx.Args = append(ctx.Args, valueInt)
124141
}
125142
case "@in":
126143
if len(v.CallExpr.Args) != 2 {

store/db/sqlite/memo_filter.go

+18-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
5959
if err != nil {
6060
return err
6161
}
62-
if !slices.Contains([]string{"create_time", "update_time", "visibility", "content"}, identifier) {
62+
if !slices.Contains([]string{"creator_id", "create_time", "update_time", "visibility", "content"}, identifier) {
6363
return errors.Errorf("invalid identifier for %s", v.CallExpr.Function)
6464
}
6565
value, err := filter.GetConstValue(v.CallExpr.Args[1])
@@ -121,6 +121,23 @@ func (d *DB) ConvertExprToSQL(ctx *filter.ConvertContext, expr *exprv1.Expr) err
121121
return err
122122
}
123123
ctx.Args = append(ctx.Args, valueStr)
124+
} else if identifier == "creator_id" {
125+
if operator != "=" && operator != "!=" {
126+
return errors.Errorf("invalid operator for %s", v.CallExpr.Function)
127+
}
128+
valueInt, ok := value.(int64)
129+
if !ok {
130+
return errors.New("invalid int value")
131+
}
132+
133+
var factor string
134+
if identifier == "creator_id" {
135+
factor = "`memo`.`creator_id`"
136+
}
137+
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s ?", factor, operator)); err != nil {
138+
return err
139+
}
140+
ctx.Args = append(ctx.Args, valueInt)
124141
}
125142
case "@in":
126143
if len(v.CallExpr.Args) != 2 {

store/db/sqlite/memo_filter_test.go

+5
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ func TestConvertExprToSQL(t *testing.T) {
6969
want: "NOT (`memo`.`pinned` IS TRUE)",
7070
args: []any{},
7171
},
72+
{
73+
filter: `creator_id == 101 || visibility in ["PUBLIC", "PRIVATE"]`,
74+
want: "(`memo`.`creator_id` = ? OR `memo`.`visibility` IN (?,?))",
75+
args: []any{int64(101), "PUBLIC", "PRIVATE"},
76+
},
7277
}
7378

7479
for _, tt := range tests {

0 commit comments

Comments
 (0)