import torch
from torch import nn
import numpy as np
from collections import OrderedDict
import math
import torch.nn.functional as F

########################
# Define all activation functions that are not in the PyTorch library.
########################
class Swish(nn.Module):
    '''Swish / SiLU activation: ``x * sigmoid(x)``.'''

    def __init__(self):
        super().__init__()
        # kept as a named submodule so the module tree is unchanged
        self.Sigmoid = nn.Sigmoid()

    def forward(self, x):
        gate = self.Sigmoid(x)
        return x * gate

class Sine(nn.Module):
    '''
    Sine activation ``sin(w0 * x)`` as used in SIREN.

    Args:
        w0: frequency multiplier. Stored as a plain attribute (not a learnable parameter),
            so it can be reassigned after construction.
    '''
    def __init__(self, w0 = 30):
        # nn.Module.__init__ must run before any attribute is assigned on the module
        super().__init__()
        self.w0 = w0

    def forward(self, input):
        return torch.sin(self.w0 * input)

class Sine_tw(nn.Module):
    '''Sine activation ``sin(w0 * x)`` whose frequency ``w0`` is a trainable 1-element float32 parameter.'''

    def __init__(self, w0 = 30):
        super().__init__()
        initial_w0 = torch.tensor([w0], dtype = torch.float32)
        self.w0 = nn.Parameter(initial_w0)

    def forward(self, input):
        scaled = self.w0 * input
        return torch.sin(scaled)


########################
# Define all the basic layers
########################

class BatchLinear(nn.Linear):
    '''
    Linear layer whose parameters may optionally be supplied at call time.

    Construct exactly like ``nn.Linear(in_features, out_features, bias=True)``.

    forward(input, params=None):
        input:  tensor whose last axis has ``in_features`` entries.
        params: optional mapping providing
            'weight' of shape (..., out_features, in_features) and, optionally,
            'bias'   of shape (..., out_features).
            Leading (batch) axes of the weight broadcast against ``input`` through
            ``torch.matmul``, so per-sample weights (e.g. produced by a hyper-network)
            can be used. When ``params`` is None the module's own parameters are used.
        returns: ``input @ weight^T (+ bias)``.
    '''
    def forward(self, input, params=None):
        if params is None:
            params = OrderedDict(self.named_parameters())

        weight = params['weight']
        bias = params.get('bias', None)
        # swap only the last two axes so batched weights (..., out, in) keep their batch axes
        output = torch.matmul(input, weight.transpose(-1, -2))
        if bias is not None:
            # unsqueeze(-2) aligns a (possibly batched) bias (..., out) with the row axis of (..., N, out)
            output += bias.unsqueeze(-2)
        return output


class MLP(nn.Module):
    '''
    Fully connected network built from BatchLinear layers with a selectable activation.

    Architecture:
        Linear(in, hidden) -> nl
        -> num_hidden_layers x [Linear(hidden, hidden) -> nl]
        -> Linear(hidden, out) [-> nl when outermost_linear is False]

    Args:
        in_features, out_features: input / output channel counts. When ``premap_mode`` is
            given, ``in_features`` is replaced by the premap's output size ``dim``.
        num_hidden_layers: number of hidden->hidden layers between the first and last layer.
        hidden_features: width of the hidden layers.
        outermost_linear: if False, the activation is also applied after the last layer.
        nonlinearity: key of the activation table below ('sine', 'relu', 'sigmoid', 'tanh',
            'selu', 'softplus', 'elu', 'swish'); an unknown key raises KeyError.
        weight_init: optional init function applied with ``Module.apply``; defaults to the
            activation's own init scheme.
        output_mode: 'single' -> forward returns the output tensor;
            'double' -> forward returns ``(output, x)`` where ``x`` is a detached copy of the
            (premapped) input with ``requires_grad=True``.
        premap_mode: optional FeatureMapping mode applied to the input; ``**kwargs`` are
            forwarded to FeatureMapping.
    '''
    def __init__(self, in_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=True, nonlinearity='relu', weight_init=None, output_mode='single',
                 premap_mode = None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_features, mode=premap_mode, **kwargs)
            in_features = self.premap_layer.dim  # the first layer consumes the premapped features

        self.first_layer_init = None
        self.output_mode = output_mode

        # activation name -> (activation module, default weight init, optional first-layer init)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # an explicitly passed weight_init overrides the activation's default scheme
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        layers = [nn.Sequential(BatchLinear(in_features, hidden_features), nl)]
        for _ in range(num_hidden_layers):
            layers.append(nn.Sequential(BatchLinear(hidden_features, hidden_features), nl))
        if outermost_linear:
            layers.append(nn.Sequential(BatchLinear(hidden_features, out_features)))
        else:
            layers.append(nn.Sequential(BatchLinear(hidden_features, out_features), nl))

        self.net = nn.Sequential(*layers)

        if self.weight_init is not None:
            self.net.apply(self.weight_init)

        if first_layer_init is not None:  # special first-layer scheme (e.g. SIREN) overrides the above
            self.net[0].apply(first_layer_init)

    def forward(self, x):
        if self.premap_mode is not None:
            x = self.premap_layer(x)

        if self.output_mode == 'single':
            return self.net(x)

        if self.output_mode == 'double':
            # fresh leaf copy so callers can differentiate the output w.r.t. the (premapped) input
            x = x.clone().detach().requires_grad_(True)
            output = self.net(x)
            return output, x

        raise ValueError(f"unknown output_mode {self.output_mode!r}; expected 'single' or 'double'")

class MLP_rezblk(nn.Module):
    '''
    Residual MLP block computing ``0.5 * net(x) + 0.5 * x``.

    ``net`` is ``num_hidden_layers`` x [BatchLinear(hidden, hidden) -> nl]. The input and
    output both have ``hidden_features`` channels. With ``num_hidden_layers == 0``, ``net``
    is the identity and the block returns ``x``.

    Args:
        num_hidden_layers: number of hidden->hidden layers inside the block.
        hidden_features: channel count of the input, the hidden layers and the output.
        nonlinearity: key of the activation table below; an unknown key raises KeyError.
        weight_init: optional init function applied with ``Module.apply``; defaults to the
            activation's own init scheme.
        outermost_linear, output_mode: stored / accepted for signature compatibility with
            MLP; they do not affect this block.
        premap_mode: must be None. A premap would change the channel count away from
            ``hidden_features``, breaking both ``net`` and the residual sum.

    Raises:
        ValueError: if ``premap_mode`` is not None.
    '''
    def __init__(self, num_hidden_layers, hidden_features,
                 outermost_linear=True, nonlinearity='sine', weight_init=None, output_mode='single',
                 premap_mode = None, **kwargs):
        super().__init__()
        if premap_mode is not None:
            raise ValueError(
                'MLP_rezblk does not support premap_mode: the residual block maps '
                'hidden_features -> hidden_features, apply the premap before the first layer instead'
            )
        self.premap_mode = premap_mode

        self.first_layer_init = None
        self.output_mode = output_mode

        # activation name -> (activation module, default weight init, optional first-layer init)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }

        # the first-layer scheme is not used: every layer here is a hidden->hidden layer
        nl, nl_weight_init, _ = nls_and_inits[nonlinearity]

        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        self.net = nn.Sequential(*[
            nn.Sequential(BatchLinear(hidden_features, hidden_features), nl)
            for _ in range(num_hidden_layers)
        ])

        if self.weight_init is not None:
            self.net.apply(self.weight_init)

    def forward(self, x):
        # equal-weight mix of the transformed path and the identity (skip) path
        return 0.5 * self.net(x) + 0.5 * x


class MLP_reznet(nn.Module):
    '''
    Residual MLP built from MLP_rezblk blocks.

    Architecture:
        fc1: Linear(in, hidden) -> nl
        -> num_hidden_blocks x MLP_rezblk(num_hidden_layers_rez, hidden)
        -> fc2: Linear(hidden, out) [-> nl when outermost_linear is False]

    Args:
        in_features, out_features: input / output channel counts. When ``premap_mode`` is
            given, ``in_features`` is replaced by the premap's output size ``dim``.
        num_hidden_blocks: number of residual blocks.
        hidden_features: width of fc1's output, the residual blocks and fc2's input.
        num_hidden_layers_rez: number of layers inside each residual block.
        outermost_linear: if False, the activation is also applied after fc2.
        nonlinearity: activation key, used for fc1 and forwarded to every residual block.
        weight_init: optional init function. It is applied to fc1 and fc2 (defaulting to the
            activation's scheme) and forwarded to the residual blocks.
        output_mode: 'single' -> forward returns the output tensor;
            'double' -> forward returns ``(output, x)`` where ``x`` is a detached copy of the
            (premapped) input with ``requires_grad=True``.
        premap_mode: optional FeatureMapping mode applied to the input; ``**kwargs`` are
            forwarded to FeatureMapping.
    '''
    def __init__(self, in_features, out_features, num_hidden_blocks, hidden_features, num_hidden_layers_rez = 2,
                 outermost_linear=True, nonlinearity='relu', weight_init=None, output_mode='single',
                 premap_mode = None, **kwargs):
        super().__init__()

        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_features, mode=premap_mode, **kwargs)
            in_features = self.premap_layer.dim  # fc1 consumes the premapped features

        self.first_layer_init = None
        self.output_mode = output_mode
        self.outermost_linear = outermost_linear
        self.num_hidden_blocks = num_hidden_blocks

        # activation name -> (activation module, default weight init, optional first-layer init)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }

        nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        self.weight_init = weight_init if weight_init is not None else nl_weight_init
        self.nl = nl

        self.fc1 = BatchLinear(in_features, hidden_features)

        # nn.ModuleList (not a plain list) so the blocks' parameters are registered,
        # trained, saved and moved to the right device together with this module
        self.rezblks = nn.ModuleList(
            MLP_rezblk(num_hidden_layers_rez, hidden_features,
                       nonlinearity=nonlinearity, weight_init=weight_init)
            for _ in range(num_hidden_blocks)
        )

        self.fc2 = BatchLinear(hidden_features, out_features)

        if self.weight_init is not None:
            self.fc1.apply(self.weight_init)
            self.fc2.apply(self.weight_init)

        if first_layer_init is not None:  # special first-layer scheme (e.g. SIREN) overrides the above
            self.fc1.apply(first_layer_init)

    def forward(self, x):
        if self.premap_mode is not None:
            x = self.premap_layer(x)

        if self.output_mode == 'double':
            # fresh leaf copy so callers can differentiate the output w.r.t. the (premapped) input
            x = x.clone().detach().requires_grad_(True)
        elif self.output_mode != 'single':
            raise ValueError(f"unknown output_mode {self.output_mode!r}; expected 'single' or 'double'")

        output = self.nl(self.fc1(x))
        for blk in self.rezblks:
            output = blk(output)
        output = self.fc2(output)
        if not self.outermost_linear:
            output = self.nl(output)

        if self.output_mode == 'double':
            return output, x
        return output


class FeatureMapping(nn.Module):
    '''
    Input feature mapping ("premap") for Fourier-feature networks and RBF encodings.

    Modes and the resulting output size ``self.dim``:
        'basic':      B = identity (in_features x in_features)        -> dim = 2 * in_features
        'gaussian':   B ~ N(0, gaussian_tau^2) of shape (gaussian_mapping_size, in_features)
                                                                       -> dim = 2 * gaussian_mapping_size
        'positional': B = pe_init_scale * vstack(pe_scale**i * I for i in range(F))
                                                                       -> dim = 2 * F * in_features
        'rbf':        learnable ``centers`` (rbf_out_features, in_features) and ``sigmas``
                                                                       -> dim = rbf_out_features
    The three Fourier modes return ``cat([sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)], -1)``.

    This is an ``nn.Module`` so that, when it is assigned as an attribute of a network, the RBF
    ``centers``/``sigmas`` are registered with that network: they are returned by
    ``parameters()``, saved in ``state_dict`` and moved by ``.to(device)``.
    '''
    _FOURIER_MODES = ('basic', 'gaussian', 'positional')

    def __init__(self, in_features, mode = 'basic',
                 gaussian_mapping_size = 256, gaussian_rand_key = 0, gaussian_tau = 1.,
                 pe_num_freqs = 4, pe_scale = 2, pe_init_scale = 1,pe_use_nyquist=True, pe_lowest_dim = None,
                 rbf_out_features = None, rbf_range = 1., rbf_std=0.5):
        '''
        inputs:
            in_features: number of input features
            mode: 'basic' | 'gaussian' | 'positional' | 'rbf'
            gaussian_mapping_size: number of rows of B (Gaussian mapping)
            gaussian_rand_key: seed of the numpy generator drawing B (Gaussian mapping)
            gaussian_tau: standard deviation of B (Gaussian mapping)
            pe_num_freqs: number of frequencies (P.E.)
            pe_scale: base of the geometric frequency progression (P.E.)
            pe_init_scale: overall scale of B (P.E.)
            pe_use_nyquist: if true (bool or the string 'True') and pe_lowest_dim is given,
                derive the number of frequencies from pe_lowest_dim via the Nyquist rate
            pe_lowest_dim: number of samples along the lowest-resolution axis (P.E.)
            rbf_out_features: number of RBF centers (required for mode 'rbf')
            rbf_range: centers are initialised uniformly in [-rbf_range, rbf_range]
            rbf_std: initial value of every sigma

        raises:
            ValueError: unknown mode, missing rbf_out_features, or fewer than one P.E. frequency.
        '''
        super().__init__()
        self.mode = mode
        if mode == 'basic':
            self.B = np.eye(in_features)
        elif mode == 'gaussian':
            rng = np.random.default_rng(gaussian_rand_key)
            self.B = rng.normal(loc = 0., scale = gaussian_tau, size = (gaussian_mapping_size, in_features))
        elif mode == 'positional':
            if self._flag_is_true(pe_use_nyquist) and pe_lowest_dim:
                pe_num_freqs = self.get_num_frequencies_nyquist(pe_lowest_dim)
            if pe_num_freqs < 1:
                raise ValueError(f'positional encoding needs at least one frequency, got {pe_num_freqs}')
            self.B = pe_init_scale * np.vstack([(pe_scale**i) * np.eye(in_features) for i in range(pe_num_freqs)])
        elif mode == 'rbf':
            if rbf_out_features is None:
                raise ValueError("rbf_out_features must be given for mode='rbf'")
            self.centers = nn.Parameter(torch.empty((rbf_out_features, in_features), dtype = torch.float32))
            self.sigmas = nn.Parameter(torch.empty(rbf_out_features, dtype = torch.float32))
            nn.init.uniform_(self.centers, -1*rbf_range, rbf_range)
            nn.init.constant_(self.sigmas, rbf_std)
            self.dim = rbf_out_features
        else:
            raise ValueError(f"unknown feature mapping mode {mode!r}; expected one of "
                             f"{self._FOURIER_MODES + ('rbf',)}")

        if mode in self._FOURIER_MODES:
            # sin and cos halves are concatenated, doubling the number of projections
            self.dim = 2 * self.B.shape[0]

    @staticmethod
    def _flag_is_true(flag):
        '''Interpret a boolean flag that may arrive as a bool or as a string such as 'True'.'''
        if isinstance(flag, str):
            return flag.strip().lower() == 'true'
        return bool(flag)

    def forward(self, input):
        if self.mode in self._FOURIER_MODES:
            return self.fourier_mapping(input, self.B)
        elif self.mode == 'rbf':
            return self.rbf_mapping(input)
        raise ValueError(f'unknown feature mapping mode {self.mode!r}')

    def get_num_frequencies_nyquist(self, samples):
        '''Number of P.E. frequencies: floor(log2(samples / 4)).'''
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        # log2 is exact for powers of two, unlike log(x, 2)
        return int(math.floor(math.log2(nyquist_rate)))

    # Fourier feature mapping
    @staticmethod
    def fourier_mapping(x, B):
        '''
        Map x to [sin(2*pi*x @ B.T), cos(2*pi*x @ B.T)] along the last axis.

        x: tensor whose last axis has B.shape[1] entries.
        B: projection matrix (numpy array or tensor), or None to return x unchanged.
        '''
        if B is None:
            return x
        # match x's floating dtype (float32 for non-float x) so the matmul dtypes agree
        dtype = x.dtype if torch.is_floating_point(x) else torch.float32
        B = torch.as_tensor(B, dtype = dtype, device = x.device)
        x_proj = (2.*np.pi*x) @ B.T
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

    # rbf mapping
    def rbf_mapping(self, x):
        '''Return gaussian(sigma_k * ||x - c_k||^2) for every center k, shape (..., rbf_out_features).'''
        size = (x.shape[:-1]) + self.centers.shape
        x = x.unsqueeze(-2).expand(size)
        # NOTE(review): gaussian() squares its argument again, so the response is
        # exp(-(sigma * ||x - c||^2)^2); a conventional RBF is exp(-(||x - c|| / sigma)^2). Confirm intended.
        distances = (x - self.centers).pow(2).sum(-1) * self.sigmas
        return self.gaussian(distances)

    @staticmethod
    def gaussian(alpha):
        '''Elementwise exp(-alpha^2).'''
        phi = torch.exp(-1 * alpha.pow(2))
        return phi

# multiplicative 
class MFNBase(nn.Module):
    """
    Multiplicative filter network base class.

    Expects the child class to define the 'filters' attribute, which should be
    a nn.ModuleList of num_hidden_layers+1 filters with output equal to hidden_features.
    A child may also set ``premap_mode`` (and ``premap_layer`` when it is not None) to
    transform the input before the filters are applied.

    forward computes
        out_0 = filters[0](x)
        out_i = filters[i](x) * linear[i-1](out_{i-1})     for i = 1 .. num_hidden_layers
        y     = output_linear(out_last), passed through sin() when ``output_act`` is True.

    Args:
        hidden_features: width of the filters' outputs and of the hidden linear layers.
        out_features: number of output channels.
        num_hidden_layers: number of hidden linear layers.
        weight_scale: accepted for interface compatibility; not used by this base class.
        bias: whether the hidden linear layers have a bias.
        output_act: apply ``torch.sin`` to the output.
    """

    def __init__(self, hidden_features, out_features, num_hidden_layers, weight_scale, bias=True, output_act=False):
        super().__init__()

        self.linear = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
        )
        self.output_linear = nn.Linear(hidden_features, out_features)
        self.output_act = output_act

        # Kaiming-uniform with a=sqrt(5) matches nn.Linear's default weight init; the weights are
        # redrawn here, which also fixes the RNG stream seen by layers created afterwards.
        for lin in self.linear:
            nn.init.kaiming_uniform_(lin.weight, a = math.sqrt(5))

    def forward(self, x):
        # premap_mode / premap_layer are optional and only set by subclasses
        if getattr(self, 'premap_mode', None) is not None:
            x = self.premap_layer(x)

        out = self.filters[0](x)
        for i, filt in enumerate(self.filters[1:]):
            # modulate a fresh filter response of the input by a linear map of the running output
            out = filt(x) * self.linear[i](out)
        out = self.output_linear(out)

        if self.output_act:
            out = torch.sin(out)

        return out


class FourierLayer(nn.Module):
    """
    Sinusoidal filter used by FourierNet: ``sin(W x + b)``.

    After PyTorch's default nn.Linear initialisation, the weights are multiplied by
    ``weight_scale`` (gamma) and the bias is redrawn uniformly from [-pi, pi].
    """

    def __init__(self, in_features, out_features, weight_scale):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        with torch.no_grad():
            self.linear.weight.mul_(weight_scale)  # gamma
            self.linear.bias.uniform_(-np.pi, np.pi)

    def forward(self, x):
        phase = self.linear(x)
        return torch.sin(phase)


class FourierNet(MFNBase):
    '''
    Multiplicative filter network whose filters are FourierLayer (``sin(Wx + b)``) layers.

    Args:
        in_features: input channels. When ``premap_mode`` is given, this is replaced by the
            premap's output size ``dim`` for the filters.
        out_features, num_hidden_layers, hidden_features, weight_scale, bias, output_act:
            see MFNBase.
        input_scale: overall filter frequency scale; each of the num_hidden_layers + 1
            filters uses ``input_scale / sqrt(num_hidden_layers + 1)``.
        premap_mode: optional FeatureMapping mode; ``**kwargs`` are forwarded to FeatureMapping.
    '''
    def __init__(
        self,
        in_features,
        out_features,
        num_hidden_layers,
        hidden_features,
        input_scale=256.0,
        weight_scale=1.0,
        bias=True,
        output_act=False,
        premap_mode = None,
        **kwargs
    ):
        super().__init__(
            hidden_features, out_features, num_hidden_layers, weight_scale, bias, output_act
        )
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_features, mode=premap_mode, **kwargs)
            # only read dim when a premap exists; the filters consume the premapped features
            in_features = self.premap_layer.dim

        self.filters = nn.ModuleList(
            [
                FourierLayer(in_features, hidden_features, input_scale / np.sqrt(num_hidden_layers + 1))
                for _ in range(num_hidden_layers + 1)
            ]
        )

class GaborLayer(nn.Module):
    """
    Gabor-like filter used by GaborNet: ``sin(W x + b) * exp(-0.5 * gamma * ||x - mu||^2)``.

    ``mu`` (out_features, in_features) is initialised uniformly in [-1, 1]. ``gamma``
    (out_features,) is drawn from Gamma(alpha, beta). Weights are scaled by
    ``weight_scale * sqrt(gamma)`` row-wise, and the bias is redrawn uniformly from [-pi, pi].
    """

    def __init__(self, in_features, out_features, weight_scale, alpha=1.0, beta=1.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.mu = nn.Parameter(2 * torch.rand(out_features, in_features) - 1)
        gamma_dist = torch.distributions.gamma.Gamma(alpha, beta)
        self.gamma = nn.Parameter(gamma_dist.sample((out_features,)))
        row_scale = weight_scale * torch.sqrt(self.gamma[:, None])
        self.linear.weight.data *= row_scale
        self.linear.bias.data.uniform_(-np.pi, np.pi)

    def forward(self, x):
        # squared distance ||x - mu||^2 expanded as |x|^2 + |mu|^2 - 2 x.mu
        x_sq = (x ** 2).sum(-1)[..., None]
        mu_sq = (self.mu ** 2).sum(-1)[None, :]
        cross = 2 * x @ self.mu.T
        sq_dist = x_sq + mu_sq - cross
        envelope = torch.exp(-0.5 * sq_dist * self.gamma[None, :])
        carrier = torch.sin(self.linear(x))
        return carrier * envelope


class GaborNet(MFNBase):
    '''
    Multiplicative filter network whose filters are GaborLayer layers.

    Args:
        in_features: input channels. When ``premap_mode`` is given, this is replaced by the
            premap's output size ``dim`` for the filters.
        out_features, num_hidden_layers, hidden_features, weight_scale, bias, output_act:
            see MFNBase.
        input_scale: overall filter frequency scale; each of the num_hidden_layers + 1
            filters uses ``input_scale / sqrt(num_hidden_layers + 1)``.
        alpha, beta: Gamma-distribution parameters for the filters' gamma; alpha is divided
            by ``num_hidden_layers + 1`` per filter.
        premap_mode: optional FeatureMapping mode; ``**kwargs`` are forwarded to FeatureMapping.
    '''
    def __init__(
        self,
        in_features,
        out_features,
        num_hidden_layers,
        hidden_features,
        input_scale=256.0,
        weight_scale=1.0,
        alpha=6.0,
        beta=1.0,
        bias=True,
        output_act=False,
        premap_mode = None,
        **kwargs
    ):
        super().__init__(
            hidden_features, out_features, num_hidden_layers, weight_scale, bias, output_act
        )

        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_features, mode=premap_mode, **kwargs)
            # only read dim when a premap exists; the filters consume the premapped features
            in_features = self.premap_layer.dim

        self.filters = nn.ModuleList(
            [
                GaborLayer(
                    in_features,
                    hidden_features,
                    input_scale / np.sqrt(num_hidden_layers + 1),
                    alpha / (num_hidden_layers + 1),
                    beta,
                )
                for _ in range(num_hidden_layers + 1)
            ]
        )

########################
# conditioned neural field  
########################

# auto encoder: full projection 
# siren based 
# auto decoder: full projection 
# under construction
class SIRENAutoencoder_fp(nn.Module):
    '''
    SIREN-style neural field whose parameters are fully predicted by a hyper (encoder) network.

    "fp" = full projection. The hyper network maps ``priors`` to a latent vector. ``hyper_net_last``
    projects that latent onto a flat vector holding every weight and bias of the neural field
    (nf), which is then evaluated at ``coords``.

    forward(coords, priors) -> (out, latent, params)
        coords: <t,h,w,c_coord> coordinates (premapped first when premap_mode is set); the
            einsum calls require exactly these four axes.
        priors: inputs of the hyper network.
        out:    <t,h,w,out_features> field values.
        latent: hyper-network latent.
        params: flat per-location nf parameters of length ``hyper_out_features``.
    '''
    def __init__(self, hyper_in_features,  hyper_latent_features, hyper_num_hidden_layers, hyper_hidden_features,
                 nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features, hyper_nonlinearity = 'sine',nf_nonlinearity= 'sine',
                 omega_0_e = 30., omega_0 = 30. ,
                 premap_mode = None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(nf_in_features, mode=premap_mode, **kwargs)
            # only read dim when a premap exists; the nf consumes the premapped features
            nf_in_features = self.premap_layer.dim

        self.nf_in_features = nf_in_features
        self.out_features = out_features
        self.nf_num_hidden_layers = nf_num_hidden_layers
        self.nf_hidden_features = nf_hidden_features

        self.hyper_latent_features = hyper_latent_features
        self.hyper_out_features = self.calculate_num_parameters(nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features)
        self.omega_0_e = omega_0_e
        self.omega_0  = omega_0
        self.first_layer_init = None
        # activation name -> (activation module, default weight init, optional first-layer init)
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init),
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }

        self.hyper_nl, self.hyper_nl_weight_init, self.hyper_first_layer_init = nls_and_inits[hyper_nonlinearity]
        self.nf_nl, _,_ = nls_and_inits[nf_nonlinearity]
        # NOTE(review): the table holds one module instance per name, so when
        # hyper_nonlinearity == nf_nonlinearity, hyper_nl and nf_nl are the same module and the
        # assignment below also changes the hyper network's w0. Kept as-is for checkpoint compatibility.
        self.nf_nl.w0 = omega_0

        self.hyper_net = nn.ModuleList([BatchLinear(hyper_in_features,hyper_hidden_features)] +
                                  [BatchLinear(hyper_hidden_features,hyper_hidden_features) for i in range(hyper_num_hidden_layers)] +
                                  [BatchLinear(hyper_hidden_features,hyper_latent_features)])

        if self.hyper_nl_weight_init is not None:
            self.hyper_net.apply(self.hyper_nl_weight_init)

        if self.hyper_first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.hyper_net[0].apply(self.hyper_first_layer_init)
        self.hyper_net_last = BatchLinear(hyper_latent_features,self.hyper_out_features)
        self.init_hyper_layer()

    def _split_params(self, params):
        '''
        Slice the flat parameter vector into nf weights and biases.

        Layout along the last axis:
            [W1 | W_h0 .. W_h{L-1} | W2 | b1 | b_h0 .. b_h{L-1} | b2]
        with W1: (in, hidden), W_h: (hidden, hidden), W2: (hidden, out). Leading axes of
        ``params`` are kept on every piece.
        '''
        lead = params.shape[:-1]
        n_in, n_h, n_out = self.nf_in_features, self.nf_hidden_features, self.out_features
        cursor = 0

        def take(count):
            # return the next ``count`` entries and advance the cursor
            nonlocal cursor
            chunk = params[..., cursor:cursor + count]
            cursor += count
            return chunk

        w1 = take(n_in * n_h).reshape(lead + (n_in, n_h))
        hidden_ws = [take(n_h * n_h).reshape(lead + (n_h, n_h)) for _ in range(self.nf_num_hidden_layers)]
        w2 = take(n_h * n_out).reshape(lead + (n_h, n_out))
        b1 = take(n_h)
        hidden_bs = [take(n_h) for _ in range(self.nf_num_hidden_layers)]
        b2 = params[..., cursor:]
        return w1, hidden_ws, w2, b1, hidden_bs, b2

    def forward(self, coords, priors):
        # hyper network: priors -> latent
        x = priors
        for layer in self.hyper_net[:-1]:
            x = self.hyper_nl(layer(x))
        latent = self.hyper_net[-1](x)

        # latent -> flat vector of all nf weights and biases (full projection)
        params = self.hyper_net_last(latent)
        w1, hidden_ws, w2, b1, hidden_bs, b2 = self._split_params(params)

        out = self.premap_layer(coords) if self.premap_mode is not None else coords

        # evaluate the nf with its per-location weights
        out = self.nf_nl(torch.einsum('thwi,thwij->thwj', out, w1) + b1)
        for w_h, b_h in zip(hidden_ws, hidden_bs):
            out = self.nf_nl(torch.einsum('thwi,thwij->thwj', out, w_h) + b_h)
        out = torch.einsum('thwi,thwij->thwj', out, w2) + b2
        return out , latent, params

    def init_hyper_layer(self):
        '''
        Initialise ``hyper_net_last``.

        Weight: U(-a, a) with a = 1e-2 * sqrt(6 / hyper_latent_features). Presumably this is kept
        small so the generated nf parameters are dominated by the bias at initialisation.
        Bias: entry-wise U(-t, t) with
            t = 1 / nf_in_features                        for the first-layer nf weights,
            t = sqrt(6 / nf_hidden_features) / omega_0_e  for the hidden and last-layer nf weights,
            t = 1 / nf_hidden_features                    for all nf biases.
        '''
        # init weights
        w_init = nn.init.uniform_(torch.empty(self.hyper_out_features, self.hyper_latent_features),a=-np.sqrt(6.0 / self.hyper_latent_features) * 1e-2, b=np.sqrt(6.0 / self.hyper_latent_features) * 1e-2)
        # init bias
        tmp = torch.ones((self.hyper_out_features))
        tmp[:self.nf_in_features * self.nf_hidden_features] = tmp[:self.nf_in_features * self.nf_hidden_features] * 1.0 / self.nf_in_features  # 1st layer weights
        temp_num_weights =self.nf_in_features * self.nf_hidden_features+ self.nf_num_hidden_layers*self.nf_hidden_features*self.nf_hidden_features + self.nf_hidden_features*self.out_features
        tmp[self.nf_in_features * self.nf_hidden_features:temp_num_weights] = tmp[self.nf_in_features * self.nf_hidden_features:temp_num_weights] * \
                                                            np.sqrt(6.0 /self.nf_hidden_features) / self.omega_0_e  # other layer weights
        tmp[temp_num_weights:] = 1.0 / self.nf_hidden_features  # all biases
        b_init = torch.distributions.uniform.Uniform(low= -tmp,high= tmp).sample().squeeze(-1)   # uniformly distributed in the range

        with torch.no_grad():
            self.hyper_net_last.weight.data = w_init
            self.hyper_net_last.bias.data = b_init

    @staticmethod
    def calculate_num_parameters(nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features):
        '''Total number of nf weights and biases: (in+1)*h + L*(h+1)*h + (h+1)*out.'''
        return (nf_in_features+1) * nf_hidden_features+ nf_num_hidden_layers*(nf_hidden_features+1)*nf_hidden_features + (nf_hidden_features+1)*out_features

class SIRENAutoencoder_film(nn.Module):
    '''
    SIREN-style neural field conditioned on priors through a hyper (encoder) network.

    The hyper network maps ``priors`` to ``latents``. For every nf layer except the last,
    ``proj_net[i](latents)`` (a bias-free BatchLinear) is added to the output of ``nf_net[i]``
    before the nonlinearity, i.e. the latents shift each hidden pre-activation.

    forward(coords, priors) -> (out, latents)
        coords: inputs of the neural field (premapped first when premap_mode is set).
        priors: inputs of the hyper network.
        The comments below describe both as <t,h,w,c>; BatchLinear itself only requires the
        last axis to be the feature axis.

    omega_0 is written to the nf activation's ``w0``. omega_0_e is stored but not used by
    this class.
    '''
    def __init__(self, hyper_in_features,  hyper_latent_features, hyper_num_hidden_layers, hyper_hidden_features,
                 nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features, hyper_nonlinearity = 'sine',nf_nonlinearity= 'sine',
                 omega_0_e = 30., omega_0 = 30.,
                 premap_mode = None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if not self.premap_mode ==None: 
            self.premap_layer = FeatureMapping(nf_in_features,mode = premap_mode, **kwargs)
            nf_in_features = self.premap_layer.dim # update the nf in features 

        self.nf_in_features = nf_in_features
        self.out_features = out_features
        self.nf_num_hidden_layers = nf_num_hidden_layers
        self.nf_hidden_features = nf_hidden_features

        self.hyper_latent_features = hyper_latent_features
        self.omega_0_e = omega_0_e 
        self.omega_0  = omega_0 
        self.first_layer_init = None
        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme
        # different layers has different initialization schemes: 
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init), # act name, init func, first layer init func 
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }
                        
        self.hyper_nl, self.hyper_nl_weight_init, self.hyper_first_layer_init = nls_and_inits[hyper_nonlinearity]
        self.nf_nl, self.nf_nl_weight_init, self.nf_first_layer_init = nls_and_inits[nf_nonlinearity]
        # set omega:
        # NOTE(review): the table holds one module instance per name, so when
        # hyper_nonlinearity == nf_nonlinearity, hyper_nl and nf_nl are the same module and this
        # assignment also changes the hyper network's w0.
        self.nf_nl.w0 = omega_0

        # hyper network: priors -> latents
        self.hyper_net = nn.ModuleList([BatchLinear(hyper_in_features,hyper_hidden_features)] + 
                                  [BatchLinear(hyper_hidden_features,hyper_hidden_features) for i in range(hyper_num_hidden_layers)] + 
                                  [BatchLinear(hyper_hidden_features,hyper_latent_features)])

        # one bias-free projection latents -> hidden per nf layer except the output layer
        self.proj_net = nn.ModuleList([BatchLinear(hyper_latent_features,nf_hidden_features,bias = False) for i in range(nf_num_hidden_layers+1)])
        
        # neural field: coords -> out
        self.nf_net = nn.ModuleList([BatchLinear(nf_in_features,nf_hidden_features)] + 
                                  [BatchLinear(nf_hidden_features,nf_hidden_features) for i in range(nf_num_hidden_layers)] + 
                                  [BatchLinear(nf_hidden_features,out_features)]) 
    
        # initialize the hyper network 
        # (module creation and init order above/below fixes the RNG stream; keep it stable)
        if self.hyper_nl_weight_init is not None: 
            self.hyper_net.apply(self.hyper_nl_weight_init)
            self.proj_net.apply(self.hyper_nl_weight_init)

        if self.hyper_first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.hyper_net[0].apply(self.hyper_first_layer_init)
            self.proj_net[0].apply(self.hyper_first_layer_init)

        # initialize the nf network (neural field)
        if self.nf_nl_weight_init is not None: 
            self.nf_net.apply(self.nf_nl_weight_init)

        if self.nf_first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.nf_net[0].apply(self.nf_first_layer_init)


    def forward(self, coords, priors):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_priors>
        # hyper network: nonlinearity after every layer except the last
        p = priors
        for i in range(len(self.hyper_net) -1):
            p = self.hyper_net[i](p)
            p = self.hyper_nl(p)
        latents = self.hyper_net[-1](p)


        # propagate through the nf net, implementing SIREN
        if not self.premap_mode ==None:  
            x = self.premap_layer(coords)
        else: 
            x = coords

        # pass it through the nf network
        # each hidden pre-activation is shifted by a projection of the latents
        for i in range(len(self.nf_net) -1):
            # print('net debug:', self.nf_net[i](x).shape, self.proj_net[i](latents).shape)
            x = self.nf_net[i](x) + self.proj_net[i](latents)
            x = self.nf_nl(x)
        out = self.nf_net[-1](x)
        return out,  latents


# class MFNAutoencoder_fp(nn.Module):
#     '''
#     MFN network with author encoding 
#     '''
#     def __init__(self, hyper_in_features,  hyper_latent_features, hyper_num_hidden_layers, hyper_hidden_features,
#                  nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features,
#                  weight_scale = 1.,
#                  hyper_alpha = 6.0, hyper_beta = 1.0, hyper_input_scale = 256.0,
#                  nf_alpha = 6.0, nf_beta = 1.0, nf_input_scale = 256.0):
#         super().__init__()
#         self.nf_in_features = nf_in_features
#         self.out_features = out_features
#         self.nf_num_hidden_layers = nf_num_hidden_layers
#         self.nf_hidden_features = nf_hidden_features
#         self.weight_scale = weight_scale
#         self.hyper_latent_features = hyper_latent_features
#         self.hyper_out_features = self.calculate_num_parameters(nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features)
      
#         self.hyper_net = nn.ModuleList([BatchLinear(hyper_hidden_features,hyper_hidden_features) for i in range(hyper_num_hidden_layers)])
#         self.hyper_net_outputlinear = BatchLinear(hyper_hidden_features,hyper_latent_features)        

#         self.hyper_nl_weight_init,self.hyper_first_layer_init = init_weights_uniform_mfn, init_weights_uniform_mfn

#         self.hyper_filters = nn.ModuleList(
#             [
#                 GaborLayer(
#                     hyper_in_features,
#                     hyper_hidden_features,
#                     hyper_input_scale / np.sqrt(hyper_num_hidden_layers + 1),
#                     hyper_alpha / (hyper_num_hidden_layers + 1),
#                     hyper_beta,
#                 )
#                 for _ in range(hyper_num_hidden_layers + 1)
#             ])

#         self.nf_filters = nn.ModuleList(
#             [
#                 GaborLayer(
#                     nf_in_features,
#                     nf_hidden_features,
#                     nf_input_scale / np.sqrt(nf_num_hidden_layers + 1),
#                     nf_alpha / (nf_num_hidden_layers + 1),
#                     nf_beta,
#                 )
#                 for _ in range(nf_num_hidden_layers + 1)
#             ])

#         if self.hyper_nl_weight_init is not None: 
#             self.hyper_net.apply(self.hyper_nl_weight_init)

#         if self.hyper_first_layer_init is not None: # Apply special initialization to first layer, if applicable.
#             self.hyper_net[0].apply(self.hyper_nl_weight_init)
#         # print('init hyper dims: ', hyper_latent_features,self.hyper_out_features)
#         self.hyper_net_last = BatchLinear(hyper_latent_features,self.hyper_out_features)
#         # print('init hyper weight shape: ', self.hyper_net_last.weight.data.shape)
#         self.init_hyper_layer()

#     def forward(self, coords, priors):
#         # coords: <t,h,w,c_coord>
#         # latents: <t,h,w,c_priors>
#         x = self.hyper_filters[0](priors)
#         print('filter output shape:', x.shape)
#         for i in range(len(self.hyper_net) ):
#             x = self.hyper_net[i](x)*self.hyper_filters[i+1](priors)
#         latents = self.hyper_net_outputlinear(x)

#         print(latents.shape, self.hyper_net_last.weight.data.shape)
#         params = self.hyper_net_last(latents)
#         print('param shape:', params.shape)
        
#         # fill x into nf net, full projection 
#         cursors = [0]
#         cursors.append(cursors[-1] + self.nf_in_features * self.nf_hidden_features)
#         w1 = params[..., cursors[-2]:cursors[-1]].reshape(params.shape[:-1]+(self.nf_in_features,self.nf_hidden_features))
#         hidden_ws = []
#         for i in range(self.nf_num_hidden_layers):
#             cursors.append(cursors[-1] + self.nf_hidden_features*self.nf_hidden_features)
#             hidden_ws.append(params[..., cursors[-2]: cursors[-1]].reshape(params.shape[:-1]+(self.nf_hidden_features,self.nf_hidden_features)))
        
#         cursors.append(cursors[-1] + self.nf_hidden_features*self.out_features)
#         w2 = params[..., cursors[-2]:cursors[-1]].reshape(params.shape[:-1]+(self.nf_hidden_features,self.out_features))
        
#         # fill bias
#         cursors.append(cursors[-1] + self.nf_hidden_features)
#         b1 = params[..., cursors[-2]: cursors[-1]]

#         hidden_bs = []
#         for i in range(self.nf_num_hidden_layers):
#             cursors.append(cursors[-1] + self.nf_hidden_features)
#             hidden_bs.append(params[..., cursors[-2]: cursors[-1]])
        
#         b2 = params[..., cursors[-1]:]
        
#         h = self.nf_filters[0](coords)
#         # propagate through the nf net, implementing SIREN
#         h = (torch.einsum('thwi,thwij->thwj', h, w1) + b1)*self.nf_filters[1](coords)

#         for i in range(self.nf_num_hidden_layers):
#             # out = torch.sin(self.omega_0 * torch.einsum('thwi,thwij->thwj', out, hidden_ws[i]) + hidden_bs[i])
#             h = (torch.einsum('thwi,thwij->thwj', h, hidden_ws[i]) + hidden_bs[i])*self.nf_filters[i+2](coords)
        
#         print('ensum shapes: ' , out.shape, w2.shape, b2.shape)
#         out = torch.einsum('thwi,thwij->thwj', h, w2) + b2
#         return out , latents, params
    
    
    # def init_hyper_layer(self):
    #     # init weights
    #     w_init = nn.init.uniform_(torch.empty(self.hyper_out_features, self.hyper_latent_features),a=-np.sqrt(self.weight_scale / self.hyper_latent_features) * 1e-2, b=np.sqrt(self.weight_scale / self.hyper_latent_features) * 1e-2)
    #     # init bias
    #     tmp = torch.ones((self.hyper_out_features))
    #     tmp[:self.nf_in_features * self.nf_hidden_features] = tmp[:self.nf_in_features * self.nf_hidden_features] *np.sqrt(self.weight_scale/self.nf_hidden_features)  # 1st layer weights
    #     temp_num_weights =self.nf_in_features * self.nf_hidden_features+ self.nf_num_hidden_layers*self.nf_hidden_features*self.nf_hidden_features + self.nf_hidden_features*self.out_features
    #     tmp[self.nf_in_features * self.nf_hidden_features:temp_num_weights] = tmp[self.nf_in_features * self.nf_hidden_features:temp_num_weights] * \
    #                                                         np.sqrt(self.weight_scale/self.nf_hidden_features)  # other layer weights
    #     tmp[temp_num_weights:] = np.sqrt(1.0 / self.nf_hidden_features ) # all biases
    #     b_init = torch.distributions.uniform.Uniform(low= -tmp,high= tmp).sample().squeeze(-1)   # uniformly distributed in the range

    #     with torch.no_grad():
    #         self.hyper_net_last.weight.data = w_init 
    #         self.hyper_net_last.bias.data = b_init 

    # @staticmethod
    # def calculate_num_parameters(nf_in_features, out_features , nf_num_hidden_layers, nf_hidden_features):
    #     return (nf_in_features+1) * nf_hidden_features+ nf_num_hidden_layers*(nf_hidden_features+1)*nf_hidden_features + (nf_hidden_features+1)*out_features




# auto encoder: bias shift 
# auto encoder: weight and shift bias 
# auto encoder: filter

# siren auto decoder: FILM
class SIRENAutodecoder_film_single(nn.Module):
    '''
    SIREN-style neural field for auto-decoding without its own latent projection network.

    Coordinates (optionally pre-mapped by ``FeatureMapping``) go through a stack of
    ``BatchLinear`` layers with the selected nonlinearity between them; the last layer
    is linear. Conditioning, when used, is supplied directly as per-layer additive
    shifts through ``forward(coords, all_latents)``.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of every hidden layer.
        outermost_linear: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine', 'relu', 'sigmoid', 'tanh',
            'selu', 'softplus', 'elu', 'swish'); an unknown key raises KeyError.
        weight_init: init function applied to every layer with ``Module.apply``;
            defaults to the init paired with ``nonlinearity``.
        bias_init: accepted for interface compatibility; not used.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` coordinate pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='sine', weight_init=None, bias_init=None,  # forward_mode = 'without_latent',
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim
        self.first_layer_init = None
        # activation name -> (activation module, default weight init, special first-layer init)
        nls_and_inits = {'sine': (Sine(), sine_init, first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None),
                         'swish': (Swish(), init_weights_xavier, None),
                         }
        self.nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # an explicitly passed weight_init overrides the activation's default init
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        self.net1 = nn.ModuleList([BatchLinear(in_coord_features, hidden_features)] +
                                  [BatchLinear(hidden_features, hidden_features) for i in range(num_hidden_layers)] +
                                  [BatchLinear(hidden_features, out_features)])

        if self.weight_init is not None:
            self.net1.apply(self.weight_init)

        if first_layer_init is not None:  # special initialization for the first layer, if any
            self.net1[0].apply(first_layer_init)

    def _premap(self, coords):
        """Apply the optional coordinate pre-map (identity when premap_mode is None)."""
        return self.premap_layer(coords) if self.premap_mode is not None else coords

    def forward(self, coords, all_latents=None):
        """
        Evaluate the field at ``coords``.

        Args:
            coords: coordinates whose last dim is the coordinate channel
                (the original comment describes <t,h,w,c_coord>).
            all_latents: optional indexable of per-layer shifts; ``all_latents[i]`` is
                added to the output of hidden layer ``i`` before the nonlinearity, so it
                must broadcast against that output. Needs one entry per non-final layer.

        Returns:
            Output of the final linear layer.
        """
        x = self._premap(coords)
        for i, layer in enumerate(self.net1[:-1]):
            x = layer(x)
            if all_latents is not None:
                x = x + all_latents[i]
            x = self.nl(x)
        return self.net1[-1](x)

    def forward_with_latent(self, coords, latents):
        """
        Evaluate the field with a latent code projected by a ``net2`` ModuleList.

        This class never builds ``net2`` itself; the method only works if one was
        attached externally (one bias-free projection per non-final layer). Otherwise
        an AttributeError explaining the alternative is raised.
        """
        net2 = getattr(self, 'net2', None)
        if net2 is None:
            raise AttributeError(
                f"{type(self).__name__} has no latent projection network 'net2'; "
                "pass per-layer shifts via forward(coords, all_latents) instead")
        x = self._premap(coords)
        for i, layer in enumerate(self.net1[:-1]):
            x = self.nl(layer(x) + net2[i](latents))
        return self.net1[-1](x)

# residual siren auto decoder: FILM
class SIREN_rez_Autodecoder_film(nn.Module):
    '''
    Residual SIREN auto-decoder with FiLM-style additive latent conditioning.

    Layout: one input layer, ``num_hidden_blocks`` residual blocks of
    ``num_hidden_layers_rez`` layers each, and a linear output layer. Every layer
    except the output one adds a bias-free projection of the latent code
    (``net2[k](latents)``) to its pre-activation. Each residual block returns the
    average of its input and its output.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_blocks: number of residual blocks.
        hidden_features: width of every hidden layer.
        num_hidden_layers_rez: layers inside each residual block.
        outermost_linear: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine', 'relu', ...).
        weight_init: init applied to net1 and net2 with ``Module.apply``; defaults
            to the init paired with ``nonlinearity``.
        bias_init: optional function applied to net2 with ``Module.apply``.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_blocks,
                 hidden_features, num_hidden_layers_rez=2,
                 outermost_linear=False, nonlinearity='sine', weight_init=None, bias_init=None,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.num_hidden_blocks = num_hidden_blocks
        self.num_hidden_layers_rez = num_hidden_layers_rez

        self.premap_mode = premap_mode
        if premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim

        self.first_layer_init = None
        # name -> (activation module, default weight init, special first-layer init)
        activation_table = {
            'sine': (Sine(), sine_init, first_layer_sine_init),
            'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
            'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
            'tanh': (nn.Tanh(), init_weights_xavier, None),
            'selu': (nn.SELU(inplace=True), init_weights_selu, None),
            'softplus': (nn.Softplus(), init_weights_normal, None),
            'elu': (nn.ELU(inplace=True), init_weights_elu, None),
            'swish': (Swish(), init_weights_xavier, None),
        }
        self.nl, default_init, first_init = activation_table[nonlinearity]
        self.weight_init = default_init if weight_init is None else weight_init

        inner_count = num_hidden_blocks * num_hidden_layers_rez
        # neural field: input layer, all residual-block layers, output layer
        self.net1 = nn.ModuleList(
            [BatchLinear(in_coord_features, hidden_features)]
            + [BatchLinear(hidden_features, hidden_features) for _ in range(inner_count)]
            + [BatchLinear(hidden_features, out_features)]
        )
        # one bias-free latent projection per non-output layer
        self.net2 = nn.ModuleList(
            [BatchLinear(in_latent_features, hidden_features, bias=False) for _ in range(inner_count + 1)]
        )

        if self.weight_init is not None:
            for net in (self.net1, self.net2):
                net.apply(self.weight_init)

        if first_init is not None:
            for net in (self.net1, self.net2):
                net[0].apply(first_init)

        if bias_init is not None:
            self.net2.apply(bias_init)

    def forward(self, coords, latents):
        """
        Evaluate the residual field at ``coords`` conditioned on ``latents``.

        The original comments describe coords as <t,h,w,c_coord> and latents as
        <t,h,w,c_latent>; each latent projection must broadcast against the hidden
        activations.
        """
        h = coords if self.premap_mode is None else self.premap_layer(coords)

        def film_layer(k, inp):
            # linear layer k, plus latent shift k, then the nonlinearity
            return self.nl(self.net1[k](inp) + self.net2[k](latents))

        h = film_layer(0, h)
        k = 1
        for _ in range(self.num_hidden_blocks):
            skip = h
            for _ in range(self.num_hidden_layers_rez):
                h = film_layer(k, h)
                k += 1
            h = 0.5 * skip + 0.5 * h
        return self.net1[-1](h)

# siren auto decoder: FILM
class SIRENAutodecoder_tw_film(nn.Module):
    '''
    SIREN auto-decoder with FiLM-style additive latent conditioning and, by default,
    a sine activation whose frequency ``w0`` is a trainable parameter (``Sine_tw``).

    Every non-output layer adds a bias-free projection of the latent code
    (``net2[k](latents)``) to its pre-activation; the output layer is linear.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of every hidden layer.
        outermost_linear: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine_tw', 'sine', 'relu', ...).
        weight_init: init called as ``weight_init(module_list, w0=w0_init)`` on net1
            and net2; defaults to the init paired with ``nonlinearity``.
        bias_init: optional function applied to net2 with ``Module.apply``.
        w0_init: initial frequency of ``Sine_tw`` and the ``w0`` passed to weight_init.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers,
                 hidden_features, outermost_linear=False, nonlinearity='sine_tw', weight_init=None,
                 bias_init=None, w0_init=30., premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim

        self.first_layer_init = None
        # name -> (activation module, default weight init, special first-layer init)
        activation_table = {
            'sine_tw': (Sine_tw(w0=w0_init), sine_init, first_layer_sine_init),
            'sine': (Sine(), sine_init, first_layer_sine_init),
            'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
            'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
            'tanh': (nn.Tanh(), init_weights_xavier, None),
            'selu': (nn.SELU(inplace=True), init_weights_selu, None),
            'softplus': (nn.Softplus(), init_weights_normal, None),
            'elu': (nn.ELU(inplace=True), init_weights_elu, None),
            'swish': (Swish(), init_weights_xavier, None),
        }
        self.nl, default_init, first_init = activation_table[nonlinearity]
        self.weight_init = default_init if weight_init is None else weight_init

        self.net1 = nn.ModuleList(
            [BatchLinear(in_coord_features, hidden_features)]
            + [BatchLinear(hidden_features, hidden_features) for _ in range(num_hidden_layers)]
            + [BatchLinear(hidden_features, out_features)]
        )
        # one bias-free latent projection per non-output layer
        self.net2 = nn.ModuleList(
            [BatchLinear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        if self.weight_init is not None:
            # NOTE(review): unlike the sibling classes, the init is called directly on each
            # whole ModuleList with a ``w0`` keyword instead of through ``Module.apply``.
            # This assumes the chosen init accepts (module, w0=...) — confirm for any
            # non-sine nonlinearity or user-supplied weight_init.
            for net in (self.net1, self.net2):
                self.weight_init(net, w0=w0_init)

        if first_init is not None:
            for net in (self.net1, self.net2):
                net[0].apply(first_init)

        if bias_init is not None:
            self.net2.apply(bias_init)

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        The original comments describe coords as <t,h,w,c_coord> and latents as
        <t,h,w,c_latent>; each latent projection must broadcast against the hidden
        activations.
        """
        h = coords if self.premap_mode is None else self.premap_layer(coords)
        *hidden_layers, output_layer = self.net1
        for k, layer in enumerate(hidden_layers):
            h = self.nl(layer(h) + self.net2[k](latents))
        return output_layer(h)

# siren auto decoder: FILM
class SIRENAutodecoder_film(nn.Module):
    '''
    SIREN auto-decoder with FiLM-style additive latent conditioning.

    Every non-output layer adds a bias-free projection of the latent code
    (``net2[k](latents)``) to its pre-activation; the output layer is linear.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of every hidden layer.
        outermost_linear: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine', 'relu', ...).
        weight_init: init applied to net1 and net2 with ``Module.apply``; defaults
            to the init paired with ``nonlinearity``.
        bias_init: optional function applied to net2 with ``Module.apply``.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers,
                 hidden_features, outermost_linear=False, nonlinearity='sine', weight_init=None,
                 bias_init=None, premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim

        self.first_layer_init = None
        # name -> (activation module, default weight init, special first-layer init)
        activation_table = {
            'sine': (Sine(), sine_init, first_layer_sine_init),
            'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
            'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
            'tanh': (nn.Tanh(), init_weights_xavier, None),
            'selu': (nn.SELU(inplace=True), init_weights_selu, None),
            'softplus': (nn.Softplus(), init_weights_normal, None),
            'elu': (nn.ELU(inplace=True), init_weights_elu, None),
            'swish': (Swish(), init_weights_xavier, None),
        }
        self.nl, default_init, first_init = activation_table[nonlinearity]
        self.weight_init = default_init if weight_init is None else weight_init

        self.net1 = nn.ModuleList(
            [BatchLinear(in_coord_features, hidden_features)]
            + [BatchLinear(hidden_features, hidden_features) for _ in range(num_hidden_layers)]
            + [BatchLinear(hidden_features, out_features)]
        )
        # one bias-free latent projection per non-output layer
        self.net2 = nn.ModuleList(
            [BatchLinear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        if self.weight_init is not None:
            for net in (self.net1, self.net2):
                net.apply(self.weight_init)

        if first_init is not None:
            for net in (self.net1, self.net2):
                net[0].apply(first_init)

        if bias_init is not None:
            self.net2.apply(bias_init)

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        The original comments describe coords as <t,h,w,c_coord> and latents as
        <t,h,w,c_latent>; each latent projection must broadcast against the hidden
        activations.
        """
        h = coords if self.premap_mode is None else self.premap_layer(coords)
        *hidden_layers, output_layer = self.net1
        for k, layer in enumerate(hidden_layers):
            h = self.nl(layer(h) + self.net2[k](latents))
        return output_layer(h)

    def disable_gradient(self):
        """Freeze the model: set ``requires_grad = False`` on every parameter."""
        self.requires_grad_(False)

# fn auto decoder: FILM
class FNAutodecoder_film(nn.Module):
    '''
    Multiplicative Fourier filter network (MFN) auto-decoder with FiLM-style
    additive latent conditioning.

    Each stage multiplies a Fourier filter of the coordinates by (previous hidden
    state passed through a linear layer + a bias-free projection of the latent code).
    The first stage uses only the latent projection. A final linear layer produces
    the output.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden linear layers
            (there are ``num_hidden_layers + 1`` filters).
        hidden_features: width of every hidden layer.
        bias: whether the hidden linear layers have a bias.
        output_act, alpha, beta: accepted for interface compatibility; not used here.
        input_scale: total filter scale, split as input_scale / sqrt(num filters).
        weight_scale: linear weights are drawn from U(-sqrt(weight_scale / fan_in), +...).
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers,
                 hidden_features, bias=True, output_act=False, input_scale=256.0, weight_scale=1.0,
                 alpha=6.0, beta=1.0, premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the filters
            in_coord_features = self.premap_layer.dim
            print('premap dim:', self.premap_layer.dim)

        n_filters = num_hidden_layers + 1
        # MFN linear stack: hidden layers followed by the output layer
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )
        # FiLM: one bias-free latent projection per filter stage
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(n_filters)]
        )

        # fan-in scaled uniform init; MFN layers first, then the FiLM projections
        with torch.no_grad():
            for lin in (*self.net1, *self.net2):
                bound = np.sqrt(weight_scale / lin.in_features)
                lin.weight.uniform_(-bound, bound)

        filter_scale = input_scale / np.sqrt(n_filters)
        self.filters = nn.ModuleList(
            [FourierLayer(in_coord_features, hidden_features, filter_scale) for _ in range(n_filters)]
        )

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        The original comments describe coords as <t,h,w,c_coord> and latents as
        <t,h,w,c_latent>.
        """
        base = coords if self.premap_mode is None else self.premap_layer(coords)
        h = self.filters[0](base) * self.net2[0](latents)
        for filt, lin, shift in zip(self.filters[1:], self.net1, self.net2[1:]):
            h = filt(base) * (lin(h) + shift(latents))
        return self.net1[-1](h)

# gabor (MFN) auto decoder: FILM
class GNAutodecoder_film(nn.Module):
    '''
    Multiplicative Gabor filter network (MFN) auto-decoder with FiLM-style
    additive latent conditioning.

    Each stage multiplies a Gabor filter of the coordinates by (previous hidden
    state passed through a linear layer + a bias-free projection of the latent code).
    The first stage uses only the latent projection. A final linear layer produces
    the output.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden linear layers
            (there are ``num_hidden_layers + 1`` filters).
        hidden_features: width of every hidden layer.
        bias: whether the hidden linear layers have a bias.
        output_act: accepted for interface compatibility; not used here.
        input_scale: total filter scale, split as input_scale / sqrt(num filters).
        weight_scale: linear weights are drawn from U(-sqrt(weight_scale / fan_in), +...).
        alpha: split across filters as alpha / (num filters); passed to ``GaborLayer``.
        beta: passed unchanged to every ``GaborLayer``.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers,
                 hidden_features, bias=True, output_act=False, input_scale=256.0, weight_scale=1.0,
                 alpha=6.0, beta=1.0, premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the filters
            in_coord_features = self.premap_layer.dim
            print('premap dim:', self.premap_layer.dim)

        n_filters = num_hidden_layers + 1
        # MFN linear stack: hidden layers followed by the output layer
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )
        # FiLM: one bias-free latent projection per filter stage
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(n_filters)]
        )

        # fan-in scaled uniform init; MFN layers first, then the FiLM projections
        with torch.no_grad():
            for lin in (*self.net1, *self.net2):
                bound = np.sqrt(weight_scale / lin.in_features)
                lin.weight.uniform_(-bound, bound)

        filter_scale = input_scale / np.sqrt(n_filters)
        filter_alpha = alpha / n_filters
        self.filters = nn.ModuleList(
            [GaborLayer(in_coord_features, hidden_features, filter_scale, filter_alpha, beta)
             for _ in range(n_filters)]
        )

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        The original comments describe coords as <t,h,w,c_coord> and latents as
        <t,h,w,c_latent>.
        """
        base = coords if self.premap_mode is None else self.premap_layer(coords)
        h = self.filters[0](base) * self.net2[0](latents)
        for filt, lin, shift in zip(self.filters[1:], self.net1, self.net2[1:]):
            h = filt(base) * (lin(h) + shift(latents))
        return self.net1[-1](h)


# siren auto decoder: modified FILM
class SIRENAutodecoder_fp(nn.Module):
    '''
    SIREN auto-decoder with full-projection ("fp") latent conditioning.

    For every non-output layer k, a bias-free hypernetwork maps the latent code to a
    full weight-delta matrix (``hw_net[k]``) and a bias shift (``hb_net[k]``):

        x <- nl(nf_net[k](x) + dW_k(latents) @ x + db_k(latents))

    Note that ``hw_net[k]`` emits ``out * in`` values per latent vector, so memory
    grows with the number of latent vectors times the layer size.

    NOTE(review): ``hw_net[-1]`` and ``hb_net[-1]`` (the latter with
    ``out_features * out_features`` outputs) are built and initialised but never used
    by ``forward``; they are kept so existing checkpoints still load.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of every hidden layer.
        outermost_linear, bias_init: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine', 'relu', ...).
        weight_init: init applied to nf_net with ``Module.apply``; defaults to the
            init paired with ``nonlinearity``.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='sine', weight_init=None, bias_init=None,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim

        self.first_layer_init = None
        # activation name -> (activation module, default weight init, special first-layer init)
        nls_and_inits = {'sine': (Sine(), sine_init, first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None),
                         'swish': (Swish(), init_weights_xavier, None),
                         }
        self.nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # an explicitly passed weight_init overrides the activation's default init
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # the neural field itself
        self.nf_net = nn.ModuleList([BatchLinear(in_coord_features, hidden_features)] +
                                    [BatchLinear(hidden_features, hidden_features) for i in range(num_hidden_layers)] +
                                    [BatchLinear(hidden_features, out_features)])

        # hypernetworks producing weight deltas and bias shifts; they have no bias themselves
        self.hw_net = nn.ModuleList([BatchLinear(in_latent_features, in_coord_features * hidden_features, bias=False)] +
                                    [BatchLinear(in_latent_features, hidden_features * hidden_features, bias=False) for i in range(num_hidden_layers)] +
                                    [BatchLinear(in_latent_features, hidden_features * out_features, bias=False)])
        self.hb_net = nn.ModuleList([BatchLinear(in_latent_features, hidden_features, bias=False) for i in range(num_hidden_layers + 1)] +
                                    [BatchLinear(in_latent_features, out_features * out_features, bias=False)])

        if self.weight_init is not None:
            self.nf_net.apply(self.weight_init)

        if first_layer_init is not None:  # special initialization for the first layer, if any
            self.nf_net[0].apply(first_layer_init)
        self.init_hyper_layer()

    def init_hyper_layer(self):
        """Initialise both hypernetworks with ``init_weights_uniform_siren_scale``."""
        self.hw_net.apply(init_weights_uniform_siren_scale)
        self.hb_net.apply(init_weights_uniform_siren_scale)

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        Args:
            coords: coordinates of shape (..., in_coord_features) — any number of
                leading dims (previously required to be exactly <t,h,w,c_coord>).
            latents: latent codes of shape (..., in_latent_features) whose leading dims
                broadcast against those of ``coords`` (e.g. <t,h,w,c> per point, or
                <t,1,1,c> for one code per frame).

        Returns:
            Output of the final linear layer, shape (..., out_features).
        """
        x = self.premap_layer(coords) if self.premap_mode is not None else coords
        latent_batch = latents.shape[:-1]
        for i in range(len(self.nf_net) - 1):
            layer = self.nf_net[i]
            # per-latent weight delta with the same (out, in) layout as layer.weight
            delta_w = self.hw_net[i](latents).reshape(latent_batch + layer.weight.shape)
            # '...' dims broadcast, so per-frame latents need not be expanded per point
            x = layer(x) + torch.einsum('...i,...ji->...j', x, delta_w) + self.hb_net[i](latents)
            x = self.nl(x)
        return self.nf_net[-1](x)


# siren auto decoder: modified FILM
class SIRENAutodecoder_mdf_film(nn.Module):
    '''
    SIREN auto-decoder with modified ("mdf") FiLM conditioning, i.e. full projection.

    For every non-output layer k, a bias-free hypernetwork maps the latent code to a
    full weight-delta matrix (``hw_net[k]``) and a bias shift (``hb_net[k]``):

        x <- nl(nf_net[k](x) + dW_k(latents) @ x + db_k(latents))

    Note that ``hw_net[k]`` emits ``out * in`` values per latent vector, so memory
    grows with the number of latent vectors times the layer size.

    NOTE(review): ``hw_net[-1]`` and ``hb_net[-1]`` (the latter with
    ``out_features * out_features`` outputs) are built and initialised but never used
    by ``forward``; they are kept so existing checkpoints still load.

    Args:
        in_coord_features: coordinate channels before the optional pre-map.
        in_latent_features: channels of the latent code.
        out_features: output channels.
        num_hidden_layers: number of hidden-to-hidden layers.
        hidden_features: width of every hidden layer.
        outermost_linear, bias_init: accepted for interface compatibility; not used.
        nonlinearity: key of the activation table ('sine', 'relu', ...).
        weight_init: init applied to nf_net with ``Module.apply``; defaults to the
            init paired with ``nonlinearity``.
        premap_mode: if not None, ``mode`` of the ``FeatureMapping`` pre-map.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 outermost_linear=False, nonlinearity='sine', weight_init=None, bias_init=None,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            # the pre-map changes the input width seen by the first layer
            in_coord_features = self.premap_layer.dim

        self.first_layer_init = None
        # activation name -> (activation module, default weight init, special first-layer init)
        nls_and_inits = {'sine': (Sine(), sine_init, first_layer_sine_init),
                         'relu': (nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid': (nn.Sigmoid(), init_weights_xavier, None),
                         'tanh': (nn.Tanh(), init_weights_xavier, None),
                         'selu': (nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus': (nn.Softplus(), init_weights_normal, None),
                         'elu': (nn.ELU(inplace=True), init_weights_elu, None),
                         'swish': (Swish(), init_weights_xavier, None),
                         }
        self.nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        # an explicitly passed weight_init overrides the activation's default init
        self.weight_init = weight_init if weight_init is not None else nl_weight_init

        # the neural field itself
        self.nf_net = nn.ModuleList([BatchLinear(in_coord_features, hidden_features)] +
                                    [BatchLinear(hidden_features, hidden_features) for i in range(num_hidden_layers)] +
                                    [BatchLinear(hidden_features, out_features)])

        # hypernetworks producing weight deltas and bias shifts; they have no bias themselves
        self.hw_net = nn.ModuleList([BatchLinear(in_latent_features, in_coord_features * hidden_features, bias=False)] +
                                    [BatchLinear(in_latent_features, hidden_features * hidden_features, bias=False) for i in range(num_hidden_layers)] +
                                    [BatchLinear(in_latent_features, hidden_features * out_features, bias=False)])
        self.hb_net = nn.ModuleList([BatchLinear(in_latent_features, hidden_features, bias=False) for i in range(num_hidden_layers + 1)] +
                                    [BatchLinear(in_latent_features, out_features * out_features, bias=False)])

        if self.weight_init is not None:
            self.nf_net.apply(self.weight_init)

        if first_layer_init is not None:  # special initialization for the first layer, if any
            self.nf_net[0].apply(first_layer_init)
        self.init_hyper_layer()

    def init_hyper_layer(self):
        """Initialise both hypernetworks with ``init_weights_uniform_siren_scale``."""
        self.hw_net.apply(init_weights_uniform_siren_scale)
        self.hb_net.apply(init_weights_uniform_siren_scale)

    def forward(self, coords, latents):
        """
        Evaluate the field at ``coords`` conditioned on ``latents``.

        Args:
            coords: coordinates of shape (..., in_coord_features) — any number of
                leading dims (previously required to be exactly <t,h,w,c_coord>).
            latents: latent codes of shape (..., in_latent_features) whose leading dims
                broadcast against those of ``coords`` (e.g. <t,h,w,c> per point, or
                <t,1,1,c> for one code per frame).

        Returns:
            Output of the final linear layer, shape (..., out_features).
        """
        x = self.premap_layer(coords) if self.premap_mode is not None else coords
        latent_batch = latents.shape[:-1]
        for i in range(len(self.nf_net) - 1):
            layer = self.nf_net[i]
            # per-latent weight delta with the same (out, in) layout as layer.weight
            delta_w = self.hw_net[i](latents).reshape(latent_batch + layer.weight.shape)
            # '...' dims broadcast, so per-frame latents need not be expanded per point
            x = layer(x) + torch.einsum('...i,...ji->...j', x, delta_w) + self.hb_net[i](latents)
            x = self.nl(x)
        return self.nf_net[-1](x)

# fn auto decoder: FILM
class FNAutodecoder_mdf_film(nn.Module):
    '''
    Fourier multiplicative filter network (MFN) with latent conditioning, for
    auto-decoding.

    The coordinates are optionally passed through a ``FeatureMapping`` premap
    layer, giving x0, which feeds every filter. Per-point latents enter each
    layer through an additive shift predicted by a bias-free linear map
    (``net2``):

        z = filters[0](x0) * net2[0](latents)
        z = filters[i](x0) * (net1[i-1](z) + net2[i](latents))   for i >= 1
        out = net1[-1](z), followed by sin(out) when ``output_act`` is True

    Args:
        in_coord_features: coordinate channels. Replaced by the premap layer's
            ``dim`` when ``premap_mode`` is given.
        in_latent_features: latent channels.
        out_features: output channels.
        num_hidden_layers: number of hidden ``net1`` layers. There are
            ``num_hidden_layers + 1`` filters and latent maps.
        hidden_features: width of the filters and hidden layers.
        bias: whether the hidden ``net1`` layers have a bias. The output
            layer always does.
        output_act: if True, apply ``torch.sin`` to the network output.
        input_scale: divided by ``sqrt(num_hidden_layers + 1)`` and passed to
            every ``FourierLayer``.
        weight_scale: all ``net1``/``net2`` weights are drawn from
            U(-sqrt(weight_scale / fan_in), sqrt(weight_scale / fan_in)).
            Biases keep nn.Linear's default init.
        alpha, beta: unused by this Fourier variant. They are kept so the
            signature matches the Gabor variants.
        premap_mode: if not None, the ``mode`` passed to ``FeatureMapping``.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 bias=True, output_act=False, input_scale=256.0, weight_scale=1.0, alpha=6.0, beta=1.0,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        self.output_act = output_act
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            in_coord_features = self.premap_layer.dim  # the filters consume the premapped features
            print('premap dim:', self.premap_layer.dim)

        # MFN linear layers: hidden -> hidden (num_hidden_layers times), then hidden -> out.
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )

        # Latent shifts: one bias-free latent -> hidden map per filter.
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        # Re-draw every linear weight (net1 first, then net2, to keep RNG order).
        for lin in list(self.net1) + list(self.net2):
            bound = np.sqrt(weight_scale / lin.weight.shape[1])
            lin.weight.data.uniform_(-bound, bound)

        self.filters = nn.ModuleList(
            [
                FourierLayer(in_coord_features, hidden_features, input_scale / np.sqrt(num_hidden_layers + 1))
                for _ in range(num_hidden_layers + 1)
            ]
        )

    def forward(self, coords, latents):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_latent>
        x0 = coords if self.premap_mode is None else self.premap_layer(coords)

        x = self.filters[0](x0) * self.net2[0](latents)
        for i in range(1, len(self.filters)):
            x = self.filters[i](x0) * (self.net1[i - 1](x) + self.net2[i](latents))
        x = self.net1[-1](x)

        # getattr keeps models pickled before `output_act` was stored loadable.
        if getattr(self, 'output_act', False):
            x = torch.sin(x)
        return x

# auto decoder: FILM
class GNAutodecoder_mdf_film(nn.Module):
    '''
    Gabor multiplicative filter network (MFN) with latent conditioning, for
    auto-decoding.

    The coordinates are optionally passed through a ``FeatureMapping`` premap
    layer, giving x0, which feeds every ``GaborLayer`` filter. Per-point
    latents enter each layer through an additive shift predicted by a
    bias-free linear map (``net2``):

        z = filters[0](x0) * net2[0](latents)
        z = filters[i](x0) * (net1[i-1](z) + net2[i](latents))   for i >= 1
        out = net1[-1](z), followed by sin(out) when ``output_act`` is True

    Args:
        in_coord_features: coordinate channels. Replaced by the premap layer's
            ``dim`` when ``premap_mode`` is given.
        in_latent_features: latent channels.
        out_features: output channels.
        num_hidden_layers: number of hidden ``net1`` layers. There are
            ``num_hidden_layers + 1`` filters and latent maps.
        hidden_features: width of the filters and hidden layers.
        bias: whether the hidden ``net1`` layers have a bias. The output
            layer always does.
        output_act: if True, apply ``torch.sin`` to the network output.
        input_scale: divided by ``sqrt(num_hidden_layers + 1)`` and passed to
            every ``GaborLayer``.
        weight_scale: all ``net1``/``net2`` weights are drawn from
            U(-sqrt(weight_scale / fan_in), sqrt(weight_scale / fan_in)).
            Biases keep nn.Linear's default init.
        alpha: divided by ``num_hidden_layers + 1`` and passed to every
            ``GaborLayer``. Its meaning is defined by GaborLayer.
        beta: passed unchanged to every ``GaborLayer``.
        premap_mode: if not None, the ``mode`` passed to ``FeatureMapping``.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 bias=True, output_act=False, input_scale=256.0, weight_scale=1.0, alpha=6.0, beta=1.0,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        self.output_act = output_act
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            in_coord_features = self.premap_layer.dim  # the filters consume the premapped features
            print('premap dim:', self.premap_layer.dim)

        # MFN linear layers: hidden -> hidden (num_hidden_layers times), then hidden -> out.
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )

        # Latent shifts: one bias-free latent -> hidden map per filter.
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        # Re-draw every linear weight (net1 first, then net2, to keep RNG order).
        for lin in list(self.net1) + list(self.net2):
            bound = np.sqrt(weight_scale / lin.weight.shape[1])
            lin.weight.data.uniform_(-bound, bound)

        self.filters = nn.ModuleList(
            [
                GaborLayer(
                    in_coord_features,
                    hidden_features,
                    input_scale / np.sqrt(num_hidden_layers + 1),
                    alpha / (num_hidden_layers + 1),
                    beta,
                )
                for _ in range(num_hidden_layers + 1)
            ]
        )

    def forward(self, coords, latents):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_latent>
        x0 = coords if self.premap_mode is None else self.premap_layer(coords)

        x = self.filters[0](x0) * self.net2[0](latents)
        for i in range(1, len(self.filters)):
            x = self.filters[i](x0) * (self.net1[i - 1](x) + self.net2[i](latents))
        x = self.net1[-1](x)

        # getattr keeps models pickled before `output_act` was stored loadable.
        if getattr(self, 'output_act', False):
            x = torch.sin(x)
        return x


# siren auto decoder: FILM
class SIRENAutodecoder_mdf2_film(nn.Module):
    '''
    SIREN-style coordinate network with per-layer latent conditioning, for
    auto-decoding.

    ``net1`` is the coordinate network (BatchLinear layers). Every layer
    except the last is shifted by the output of its own latent MLP in
    ``nested_net2`` before the nonlinearity:

        x = nl(net1[i](x) + nested_net2[i](latents))   for i < len(net1) - 1
        out = net1[-1](x)

    Args (as used below):
        in_coord_features: coordinate channels. Replaced by the premap layer's
            ``dim`` when ``premap_mode`` is given.
        in_latent_features: latent channels fed to every ``nested_net2`` MLP.
        out_features: output channels of ``net1``.
        num_hidden_layers: number of hidden -> hidden layers in ``net1``.
        hidden_features: width of ``net1``, and output width of each latent MLP.
        film_num_hidden_layers, film_hidden_features: depth and width passed
            to each latent ``MLP``.
        nonlinearity: key into ``nls_and_inits``. Selects ``self.nl`` and the
            default weight and first-layer initializers.
        weight_init: overrides the default weight initializer if given.
        bias_init: if given, applied to ``net[0]`` of each latent MLP.
        premap_mode, **kwargs: forwarded to ``FeatureMapping`` when
            ``premap_mode`` is not None.

    NOTE(review): ``film_nonlinearity`` and ``outermost_linear`` are accepted
    but never used here. Each latent MLP is built with MLP's default
    ``nonlinearity`` ('relu' per its signature). Confirm this is intended.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                film_num_hidden_layers, film_hidden_features, film_nonlinearity='sine',
                 outermost_linear=False, nonlinearity='sine', weight_init=None,bias_init=None,
                 premap_mode = None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        if not self.premap_mode ==None: 
            self.premap_layer = FeatureMapping(in_coord_features,mode = premap_mode, **kwargs)
            in_coord_features = self.premap_layer.dim # net1 consumes the premapped features

        # NOTE: this attribute stays None; the local `first_layer_init` unpacked
        # below is the one actually used for initialization.
        self.first_layer_init = None
        # Dictionary that maps nonlinearity name to the respective function, initialization, and, if applicable,
        # special first-layer initialization scheme.
        # Different nonlinearities use different initialization schemes.
        nls_and_inits = {'sine':(Sine(), sine_init, first_layer_sine_init), # (activation, weight init, first-layer init)
                         'relu':(nn.ReLU(inplace=True), init_weights_normal, None),
                         'sigmoid':(nn.Sigmoid(), init_weights_xavier, None),
                         'tanh':(nn.Tanh(), init_weights_xavier, None),
                         'selu':(nn.SELU(inplace=True), init_weights_selu, None),
                         'softplus':(nn.Softplus(), init_weights_normal, None),
                         'elu':(nn.ELU(inplace=True), init_weights_elu, None),
                         'swish':(Swish(), init_weights_xavier, None),
                         }
                        
        self.nl, nl_weight_init, first_layer_init = nls_and_inits[nonlinearity]

        if weight_init is not None:  # Overwrite weight init if passed
            self.weight_init = weight_init
        else:
            self.weight_init = nl_weight_init # default init for the chosen nonlinearity

        # Coordinate network: in -> hidden, (hidden -> hidden) x num_hidden_layers, hidden -> out.
        self.net1 = nn.ModuleList([BatchLinear(in_coord_features,hidden_features)] + 
                                  [BatchLinear(hidden_features,hidden_features) for i in range(num_hidden_layers)] + 
                                  [BatchLinear(hidden_features,out_features)])
        # One latent MLP (latent -> hidden_features shift) per non-output layer of net1.
        self.nested_net2 = nn.ModuleList([MLP(in_latent_features,hidden_features,film_num_hidden_layers, film_hidden_features) for i in range(num_hidden_layers+1)])

        # The statement order below fixes the order in which the initializers
        # consume the RNG; reordering changes the weights drawn for a given seed.
        if self.weight_init is not None:
            self.net1.apply(self.weight_init)
            # self.net2.apply(self.weight_init)
            for net in self.nested_net2:
                net.apply(self.weight_init)

        if first_layer_init is not None: # Apply special initialization to first layer, if applicable.
            self.net1[0].apply(first_layer_init)
            # NOTE(review): assumes MLP supports indexing (net[0] = its first
            # layer); MLP's body is not visible here -- confirm.
            for net in self.nested_net2:
                net[0].apply(first_layer_init)
            # self.net2[0].apply(first_layer_init)
        if bias_init is not None:
            # Only the first layer of each latent MLP receives bias_init.
            for net in self.nested_net2:
                net[0].apply(bias_init)
            # self.net2.apply(bias_init)

    def forward(self, coords, latents):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_latent>
        
        # Optional coordinate pre-mapping.
        if not self.premap_mode ==None: 
            x = self.premap_layer(coords)
        else: 
            x = coords

        # Every layer but the last: linear map, add the latent shift, then nonlinearity.
        for i in range(len(self.net1) -1):
            x = self.net1[i](x) + self.nested_net2[i](latents)
            x = self.nl(x)
        # Output layer has no latent shift and no nonlinearity.
        x = self.net1[-1](x)
        return x 

# fn auto decoder: FILM
class FNAutodecoder_film(nn.Module):
    '''
    Fourier multiplicative filter network (MFN) with latent conditioning, for
    auto-decoding.

    The coordinates are optionally passed through a ``FeatureMapping`` premap
    layer, giving x0, which feeds every filter. Per-point latents enter each
    layer through an additive shift predicted by a bias-free linear map
    (``net2``):

        z = filters[0](x0) * net2[0](latents)
        z = filters[i](x0) * (net1[i-1](z) + net2[i](latents))   for i >= 1
        out = net1[-1](z), followed by sin(out) when ``output_act`` is True

    Args:
        in_coord_features: coordinate channels. Replaced by the premap layer's
            ``dim`` when ``premap_mode`` is given.
        in_latent_features: latent channels.
        out_features: output channels.
        num_hidden_layers: number of hidden ``net1`` layers. There are
            ``num_hidden_layers + 1`` filters and latent maps.
        hidden_features: width of the filters and hidden layers.
        bias: whether the hidden ``net1`` layers have a bias. The output
            layer always does.
        output_act: if True, apply ``torch.sin`` to the network output.
        input_scale: divided by ``sqrt(num_hidden_layers + 1)`` and passed to
            every ``FourierLayer``.
        weight_scale: all ``net1``/``net2`` weights are drawn from
            U(-sqrt(weight_scale / fan_in), sqrt(weight_scale / fan_in)).
            Biases keep nn.Linear's default init.
        alpha, beta: unused by this Fourier variant. They are kept so the
            signature matches the Gabor variants.
        premap_mode: if not None, the ``mode`` passed to ``FeatureMapping``.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 bias=True, output_act=False, input_scale=256.0, weight_scale=1.0, alpha=6.0, beta=1.0,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        self.output_act = output_act
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            in_coord_features = self.premap_layer.dim  # the filters consume the premapped features
            print('premap dim:', self.premap_layer.dim)

        # MFN linear layers: hidden -> hidden (num_hidden_layers times), then hidden -> out.
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )

        # Latent shifts: one bias-free latent -> hidden map per filter.
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        # Re-draw every linear weight (net1 first, then net2, to keep RNG order).
        for lin in list(self.net1) + list(self.net2):
            bound = np.sqrt(weight_scale / lin.weight.shape[1])
            lin.weight.data.uniform_(-bound, bound)

        self.filters = nn.ModuleList(
            [
                FourierLayer(in_coord_features, hidden_features, input_scale / np.sqrt(num_hidden_layers + 1))
                for _ in range(num_hidden_layers + 1)
            ]
        )

    def forward(self, coords, latents):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_latent>
        x0 = coords if self.premap_mode is None else self.premap_layer(coords)

        x = self.filters[0](x0) * self.net2[0](latents)
        for i in range(1, len(self.filters)):
            x = self.filters[i](x0) * (self.net1[i - 1](x) + self.net2[i](latents))
        x = self.net1[-1](x)

        # getattr keeps models pickled before `output_act` was stored loadable.
        if getattr(self, 'output_act', False):
            x = torch.sin(x)
        return x

# auto decoder: FILM
class GNAutodecoder_film(nn.Module):
    '''
    Gabor multiplicative filter network (MFN) with latent conditioning, for
    auto-decoding.

    The coordinates are optionally passed through a ``FeatureMapping`` premap
    layer, giving x0, which feeds every ``GaborLayer`` filter. Per-point
    latents enter each layer through an additive shift predicted by a
    bias-free linear map (``net2``):

        z = filters[0](x0) * net2[0](latents)
        z = filters[i](x0) * (net1[i-1](z) + net2[i](latents))   for i >= 1
        out = net1[-1](z), followed by sin(out) when ``output_act`` is True

    Args:
        in_coord_features: coordinate channels. Replaced by the premap layer's
            ``dim`` when ``premap_mode`` is given.
        in_latent_features: latent channels.
        out_features: output channels.
        num_hidden_layers: number of hidden ``net1`` layers. There are
            ``num_hidden_layers + 1`` filters and latent maps.
        hidden_features: width of the filters and hidden layers.
        bias: whether the hidden ``net1`` layers have a bias. The output
            layer always does.
        output_act: if True, apply ``torch.sin`` to the network output.
        input_scale: divided by ``sqrt(num_hidden_layers + 1)`` and passed to
            every ``GaborLayer``.
        weight_scale: all ``net1``/``net2`` weights are drawn from
            U(-sqrt(weight_scale / fan_in), sqrt(weight_scale / fan_in)).
            Biases keep nn.Linear's default init.
        alpha: divided by ``num_hidden_layers + 1`` and passed to every
            ``GaborLayer``. Its meaning is defined by GaborLayer.
        beta: passed unchanged to every ``GaborLayer``.
        premap_mode: if not None, the ``mode`` passed to ``FeatureMapping``.
        **kwargs: forwarded to ``FeatureMapping``.
    '''
    def __init__(self, in_coord_features, in_latent_features, out_features, num_hidden_layers, hidden_features,
                 bias=True, output_act=False, input_scale=256.0, weight_scale=1.0, alpha=6.0, beta=1.0,
                 premap_mode=None, **kwargs):
        super().__init__()
        self.premap_mode = premap_mode
        self.output_act = output_act
        if self.premap_mode is not None:
            self.premap_layer = FeatureMapping(in_coord_features, mode=premap_mode, **kwargs)
            in_coord_features = self.premap_layer.dim  # the filters consume the premapped features
            print('premap dim:', self.premap_layer.dim)

        # MFN linear layers: hidden -> hidden (num_hidden_layers times), then hidden -> out.
        self.net1 = nn.ModuleList(
            [nn.Linear(hidden_features, hidden_features, bias) for _ in range(num_hidden_layers)]
            + [nn.Linear(hidden_features, out_features)]
        )

        # Latent shifts: one bias-free latent -> hidden map per filter.
        self.net2 = nn.ModuleList(
            [nn.Linear(in_latent_features, hidden_features, bias=False) for _ in range(num_hidden_layers + 1)]
        )

        # Re-draw every linear weight (net1 first, then net2, to keep RNG order).
        for lin in list(self.net1) + list(self.net2):
            bound = np.sqrt(weight_scale / lin.weight.shape[1])
            lin.weight.data.uniform_(-bound, bound)

        self.filters = nn.ModuleList(
            [
                GaborLayer(
                    in_coord_features,
                    hidden_features,
                    input_scale / np.sqrt(num_hidden_layers + 1),
                    alpha / (num_hidden_layers + 1),
                    beta,
                )
                for _ in range(num_hidden_layers + 1)
            ]
        )

    def forward(self, coords, latents):
        # coords: <t,h,w,c_coord>
        # latents: <t,h,w,c_latent>
        x0 = coords if self.premap_mode is None else self.premap_layer(coords)

        x = self.filters[0](x0) * self.net2[0](latents)
        for i in range(1, len(self.filters)):
            x = self.filters[i](x0) * (self.net1[i - 1](x) + self.net2[i](latents))
        x = self.net1[-1](x)

        # getattr keeps models pickled before `output_act` was stored loadable.
        if getattr(self, 'output_act', False):
            x = torch.sin(x)
        return x


import math
class PosEncodingNeRF(nn.Module):
    '''Module to add positional encoding as in NeRF [Mildenhall et al. 2020].

    The input is reshaped to (batch, -1, in_features). Each point c is encoded
    as the raw coordinates followed by, for every frequency index i in
    [0, num_frequencies) (outer loop) and every input dimension j (inner
    loop), the pair (sin(2^i * pi * c_j), cos(2^i * pi * c_j)).
    ``out_dim = in_features + 2 * in_features * num_frequencies``.

    Args:
        in_features: 1, 2 or 3.
            3: 10 frequencies.
            2: requires ``sidelength`` (int or (h, w)); 4 frequencies, or the
                Nyquist-derived count from min(h, w) when ``use_nyquist``.
            1: requires ``fn_samples``; 4 frequencies, or the Nyquist-derived
                count when ``use_nyquist``.
        sidelength: grid side length(s), used when ``in_features == 2``.
        fn_samples: number of samples, used when ``in_features == 1``.
        use_nyquist: derive the frequency count from the sampling rate.

    Raises:
        ValueError: for unsupported ``in_features`` or a missing
            ``sidelength`` / ``fn_samples``.
    '''
    def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
        super().__init__()

        self.in_features = in_features
        if self.in_features == 3:
            self.num_frequencies = 10
        elif self.in_features == 2:
            if sidelength is None:
                raise ValueError('PosEncodingNeRF: sidelength is required when in_features == 2')
            if isinstance(sidelength, int):
                sidelength = (sidelength, sidelength)
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
        elif self.in_features == 1:
            if fn_samples is None:
                raise ValueError('PosEncodingNeRF: fn_samples is required when in_features == 1')
            self.num_frequencies = 4
            if use_nyquist:
                self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
        else:
            raise ValueError(f'PosEncodingNeRF supports in_features in (1, 2, 3), got {in_features}')

        self.out_dim = in_features + 2 * in_features * self.num_frequencies

    def get_num_frequencies_nyquist(self, samples):
        '''Return floor(log2(samples / 4)), clamped at 0.

        The clamp keeps ``out_dim`` valid for very small sample counts, where
        the unclamped count would be negative.
        '''
        nyquist_rate = 1 / (2 * (2 * 1 / samples))
        # log2 is exact at powers of two; log(x, 2) can land just below an integer.
        return max(0, int(math.floor(math.log2(nyquist_rate))))

    def forward(self, coords):
        coords = coords.reshape(coords.shape[0], -1, self.in_features)

        # Collect all pieces and concatenate once (linear, instead of re-copying per cat).
        pieces = [coords]
        for i in range(self.num_frequencies):
            freq = (2 ** i) * np.pi
            for j in range(self.in_features):
                c = coords[..., j]
                pieces.append(torch.unsqueeze(torch.sin(freq * c), -1))
                pieces.append(torch.unsqueeze(torch.cos(freq * c), -1))

        coords_pos_enc = torch.cat(pieces, dim=-1)
        return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)




# auto decoder: bias shift 
# auto decoder: weight and shift bias 
# auto decoder: filter


########################
# Initialization methods
########################
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    # For PINNet, Raissi et al. 2019
    # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    # grab from upstream pytorch branch and paste here for now
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to
        # [2l-1, 2u-1].
        tensor.uniform_(2 * l - 1, 2 * u - 1)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tensor.erfinv_()

        # Transform to proper mean, std
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)

        # Clamp to ensure it's in the proper range
        tensor.clamp_(min=a, max=b)
        return tensor


def init_weights_trunc_normal(m):
    # Glorot-scaled truncated-normal init (PINNet, Raissi et al. 2019), matching
    # tf.truncated_normal: std = sqrt(2 / (fan_in + fan_out)), support [-2 std, 2 std].
    # Only modules whose exact type is BatchLinear or nn.Linear are touched.
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_out, fan_in = m.weight.size(0), m.weight.size(1)
    sigma = math.sqrt(2.0 / float(fan_in + fan_out))
    _no_grad_trunc_normal_(m.weight, 0., sigma, -2 * sigma, 2 * sigma)

def init_weights_uniform(m):
    # W ~ U(-1/fan_in, 1/fan_in); exact BatchLinear / nn.Linear modules only.
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    bound = 1 / m.weight.size(-1)
    nn.init.uniform_(m.weight, -bound, bound)

def init_weights_uniform_mfn(m, weight_scale= 1.):
    # MFN-style init: W ~ U(-sqrt(weight_scale/fan_in), sqrt(weight_scale/fan_in)).
    # Exact BatchLinear / nn.Linear modules only.
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    bound = math.sqrt(weight_scale / m.weight.size(-1))
    nn.init.uniform_(m.weight, -bound, bound)


def init_weights_uniform_siren_scale(m, scale= 1e-2):
    # SIREN-style uniform init shrunk by `scale`:
    # W ~ U(-sqrt(6/fan_in) * scale, sqrt(6/fan_in) * scale).
    # Exact BatchLinear / nn.Linear modules only.
    if type(m) in (BatchLinear, nn.Linear) and hasattr(m, 'weight'):
        limit = math.sqrt(6 / m.weight.size(-1)) * scale
        nn.init.uniform_(m.weight, -limit, limit)


def init_weights_normal(m):
    # Kaiming (He) normal init for ReLU layers, fan_in mode.
    # Exact BatchLinear / nn.Linear modules only.
    is_plain_linear = type(m) in (BatchLinear, nn.Linear)
    if is_plain_linear and hasattr(m, 'weight'):
        nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')


def init_weights_selu(m):
    # LeCun-normal style init for SELU: W ~ N(0, 1/fan_in).
    # Exact BatchLinear / nn.Linear modules only.
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=1 / math.sqrt(fan_in))


def init_weights_elu(m):
    # Normal init for ELU layers: std = sqrt(1.5505188080679277) / sqrt(fan_in).
    # NOTE(review): the constant is taken as-is; presumably an ELU variance gain.
    # Exact BatchLinear / nn.Linear modules only.
    if type(m) not in (BatchLinear, nn.Linear) or not hasattr(m, 'weight'):
        return
    fan_in = m.weight.size(-1)
    nn.init.normal_(m.weight, std=math.sqrt(1.5505188080679277) / math.sqrt(fan_in))


def init_weights_xavier(m):
    # Glorot / Xavier normal init (default gain); exact BatchLinear / nn.Linear only.
    if type(m) not in (BatchLinear, nn.Linear):
        return
    if hasattr(m, 'weight'):
        nn.init.xavier_normal_(m.weight)


def sine_init(m, w0 = 30):
    # SIREN hidden-layer init: W ~ U(-sqrt(6/fan_in)/w0, sqrt(6/fan_in)/w0).
    # See the SIREN supplement Sec. 1.5 for the choice w0 = 30.
    # Any module with a `weight` attribute is affected; others are ignored.
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        fan_in = m.weight.size(-1)
        limit = np.sqrt(6 / fan_in) / w0
        m.weight.uniform_(-limit, limit)

def first_layer_sine_init(m):
    # SIREN first-layer init: W ~ U(-1/fan_in, 1/fan_in).
    # See SIREN paper Sec. 3.2 (final paragraph) and supplement Sec. 1.5.
    if not hasattr(m, 'weight'):
        return
    with torch.no_grad():
        bound = 1 / m.weight.size(-1)
        m.weight.uniform_(-bound, bound)

def init_bias_uniform(m):
    '''Initialize ``m.bias`` in place from U(-1/fan_in, 1/fan_in).

    fan_in is ``m.weight.size(-1)``. Modules without a bias tensor (no
    ``bias`` attribute, or ``bias`` is None as for ``nn.Linear(..., bias=False)``)
    or without a ``weight`` are left untouched, so the function is safe to
    pass to ``Module.apply``.
    '''
    # nn.Module raises AttributeError for missing attributes, so getattr's default applies.
    bias = getattr(m, 'bias', None)
    weight = getattr(m, 'weight', None)
    if bias is None or weight is None:
        return
    with torch.no_grad():
        bound = 1 / weight.size(-1)
        bias.uniform_(-bound, bound)

def init_bias_uniform_sqrt(m):
    '''Initialize ``m.bias`` in place from U(-1/sqrt(fan_in), 1/sqrt(fan_in)).

    fan_in is ``m.weight.size(-1)``. Modules without a bias tensor (no
    ``bias`` attribute, or ``bias`` is None as for ``nn.Linear(..., bias=False)``)
    or without a ``weight`` are left untouched, so the function is safe to
    pass to ``Module.apply``.
    '''
    # nn.Module raises AttributeError for missing attributes, so getattr's default applies.
    bias = getattr(m, 'bias', None)
    weight = getattr(m, 'weight', None)
    if bias is None or weight is None:
        return
    with torch.no_grad():
        bound = 1 / math.sqrt(weight.size(-1))
        bias.uniform_(-bound, bound)

            
#####
#meta learning 
#####
    
def GetSubDict(net):
    '''Return {parameter name: param.data} for every parameter of ``net``.

    The values are ``.data`` tensors, which share storage with the parameters.
    Each name and shape is printed as it is collected.
    '''
    print('find all params')
    collected = {}
    for pname, p in net.named_parameters():
        print(pname, p.data.shape)
        collected[pname] = p.data
    print('find all params done!')
    return collected


def GetSubDict_FILM(net):
    '''Split the parameters of a FiLM-style network into two dicts by name.

    Parameters whose name contains 'net1' go to the first dict. Of the rest,
    those whose name contains 'net2' (for example 'net2.*' or 'nested_net2.*')
    go to the second. All other parameters are ignored. Values are
    ``param.data`` tensors, which share storage with the parameters. Each name
    and shape is printed as it is visited.

    Returns:
        (params1, params2): the 'net1' and 'net2' parameter dicts.
    '''
    params1 = {}
    params2 = {}
    print('find all params')
    for name, param in net.named_parameters():
        print(name, param.data.shape)
        if 'net1' in name:
            params1[name] = param.data
        elif 'net2' in name:
            params2[name] = param.data
    print('find all params done!')
    # Return both groups (previously returned an undefined name -> NameError).
    return params1, params2

def ExtractSubDict(my_dict, name):
    '''Return a new dict with the entries of ``my_dict`` whose key contains ``name``.

    The original key order is preserved.
    '''
    return {key: my_dict[key] for key in my_dict.keys() if name in key}


def GetNumParam(model):
    '''Count the scalar elements in all trainable (requires_grad) parameters of ``model``.'''
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
        