import copy
import re

import vapoursynth as vs
from vapoursynth import core
from vsmlrt import RIFE

def RIFE_imp(clip: vs.VideoNode, multi, model, backend) -> vs.VideoNode:
    """Interpolate ``clip`` with vsmlrt's ``RIFE``, selecting the model by a version string.

    ``model`` is a string such as ``"4.6"``, ``"4.15_lite"``, ``"4.25_heavy"``
    or ``"4.6_v2"``. It is converted to the integer model id passed to ``RIFE``:

    * ``major.minor`` becomes ``major * 10 + minor`` when the minor part has one
      digit and ``major * 100 + minor`` otherwise (``"4.6"`` -> 46,
      ``"4.15"`` -> 415);
    * a ``lite`` tag appends the digit 1 and a ``heavy`` tag appends the digit 2
      (``"4.15_lite"`` -> 4151);
    * any resulting id below 30 is replaced with 46;
    * an ``ensemble`` tag enables ensemble mode;
    * a trailing ``_v2`` selects implementation 2, otherwise implementation 1.

    The backend options (forced fp16, tf32, output format 1, CUDA graphs,
    no workspace limit) are applied to a shallow copy of ``backend``, so the
    caller's backend object is left untouched.

    Args:
        clip: Input clip.
        multi: Frame-rate multiplier, forwarded to ``RIFE``.
        model: Model version string as described above.
        backend: vsmlrt backend object. Presumably a TensorRT backend, given
            the attributes set on it; confirm with callers.

    Returns:
        The clip produced by ``RIFE``.

    Raises:
        IndexError: if ``model`` has no minor version part.
        ValueError: if the major/minor parts are not integers.
    """
    # e.g. "4.15_lite_v2" -> ["4", "15", "lite", "v2"]
    parts = re.split(r"[._]", model)
    major, minor = parts[0], parts[1]
    model_num = int(major) * (10 if len(minor) == 1 else 100) + int(minor)

    # Variant tags add one more trailing digit to the id.
    if "lite" in model:
        model_num = model_num * 10 + 1
    if "heavy" in model:
        model_num = model_num * 10 + 2

    # NOTE(review): ids below 30 are mapped to 46 (v4.6); presumably older
    # versions are unsupported here. Confirm the intent.
    if model_num < 30:
        model_num = 46

    ensemble = "ensemble" in model
    implementation = 2 if model.endswith("_v2") else 1

    # Configure a copy so a backend object shared by the caller is not
    # silently modified.
    backend = copy.copy(backend)
    backend.force_fp16 = True
    backend.tf32 = True
    backend.output_format = 1
    backend.use_cuda_graph = True
    backend.workspace = None

    # Positional arguments are kept as in the original call. Per the vsmlrt
    # RIFE signature they presumably map to scale=1.0, tiles/tilesize/overlap
    # =None, video_player=False and the implementation selector. Verify
    # against the installed vsmlrt version.
    return RIFE(
        clip,
        multi,
        1.0,
        None,
        None,
        None,
        model_num,
        backend,
        ensemble,
        False,
        implementation,
    )
