import numpy as np
import torch
import os
import timeit
import copy
import argparse
import anndata
from anndata import AnnData
import scanpy as sc
from sklearn.neighbors import kneighbors_graph
import gc
from typing import Union, Tuple
import pickle
import warnings
import logging
# Silence every Python warning for the whole process from this point on.
warnings.filterwarnings("ignore")
import warnings
import logging
import pytorch_lightning as pl
#logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
# Only CRITICAL records from the "pytorch_lightning" logger are emitted.
logging.getLogger("pytorch_lightning").setLevel(logging.CRITICAL)
# NOTE(review): `fastgenerator` is used throughout this module but is not defined
# here; presumably it is provided by this star import -- confirm against vaegan.
from .vaegan import *
def setdata(name:str,sid:str,device:str='cuda:0',k:int=15,diagw:float=1.0) -> Tuple[anndata.AnnData, torch.Tensor, torch.Tensor, torch.Tensor, int]:
    """
    Load one sample and prepare everything the generator model needs.

    The sample is read from ``<name>/sample_sc/<sid>.h5ad``. When the project
    directory contains an entry called ``geneset_scores``, the sample's gene set
    scores (restricted by the mask stored in ``<name>/hvset.npy``) are appended
    to the expression matrix as extra feature columns.

    Parameters
    ----------
    name
        The project name (used as the project directory path).
    sid
        Sample ID of the sample to be prepared.
    device
        Device the feature variances are moved to.
    k
        Number of neighbors to consider in the cell graph. Not used in this
        function body; the adjacency is read from ``adata.obsm['adj']``.
    diagw
        The weight of the original cell when aggregating the information.
        Not used in this function body.

    Returns
    -------
    adata
        The (possibly augmented) anndata object, registered through
        ``fastgenerator.setup_anndata``.
    adj
        ``adata.obsm['adj']`` as a float32 tensor (not moved to ``device``).
    variances
        ``adata.uns['feature_var']`` as a tensor on ``device``.
    pseudobulk
        Column-wise mean of ``adata.X``, flattened to one dimension
        (a NumPy array).
    geneset_len
        Number of appended gene set score columns (0 when none are appended).
    """
    adata = anndata.read_h5ad(name + '/sample_sc/' + sid + '.h5ad')

    geneset_len = 0
    if 'geneset_scores' in os.listdir(name):
        # Keep only the gene set columns selected by the project-wide mask.
        scores = np.load(name + '/geneset_scores/'+sid+'.npy')
        selected = np.load(name + '/hvset.npy')
        scores = scores[:,selected].astype('float32')
        geneset_len = scores.shape[1]

        # Rebuild the AnnData around [expression | gene set scores] and carry
        # over the per-cell annotations, embeddings and unstructured metadata.
        augmented = anndata.AnnData(np.concatenate([adata.X,scores],1),dtype='float32')
        augmented.obs = adata.obs
        augmented.obsm = adata.obsm
        augmented.uns = adata.uns
        adata = augmented.copy()

    # Cell neighbor graph stored alongside the sample.
    adj = torch.from_numpy(adata.obsm['adj'].astype('float32'))

    # Per-feature variances, placed on the training device.
    variances = torch.tensor(adata.uns['feature_var']).to(device)

    # Mean over cells of the (possibly augmented) feature matrix.
    pseudobulk = np.array(adata.X.mean(axis=0)).reshape((-1))

    fastgenerator.setup_anndata(adata)
    return adata,adj,variances,pseudobulk,geneset_len
def fastrecon(name:str, sid:str, device:str='cuda:0',k:int=15,diagw:float=1.0,vaesteps:int=100,gansteps:int=100,lr:float = 1e-3,save:bool=True,path:str=None) -> fastgenerator:
    """
    Accelerated version of pretrain 1 reconstruction.

    Trains a fresh ``fastgenerator`` on one sample in two consecutive
    ``train`` calls, optionally saves its weights, and writes the training
    history to ``<name>/history/pretrain1_<sid>.pkl``.

    Parameters
    ----------
    name
        The project name (used as the project directory path).
    sid
        Sample ID of the sample to be reconstructed.
    device
        CPU or GPU for model training.
    k
        Number of neighbors to consider in the cell graph (forwarded to
        ``setdata``).
    diagw
        The weight of the original cell when aggregating the information
        (forwarded to ``setdata``).
    vaesteps
        Epochs of the first training call (``lr2`` = 0).
    gansteps
        The second training call (``lr2`` = ``lr``) runs ``gansteps * 3``
        epochs.
    lr
        Learning rate.
    save
        Whether to save the model weights.
    path
        Path for saving the model. Defaults to
        ``<name>/models/fast_reconst1_<sid>``.

    Returns
    -------
    model
        The trained model.
    """
    adata, adj, variances, bulk, geneset_len = setdata(name, sid, device, k, diagw)

    model = fastgenerator(variances, None, geneset_len, adata,
                          n_hidden=256, n_latent=32, dropout_rate=0)
    # First call: second learning rate set to 0 (per the parameter docs this is
    # the generator-only phase). Second call: both learning rates equal (joint
    # phase with the discriminator).
    model.train(max_epochs=vaesteps, plan_kwargs={'lr': lr, 'lr2': 0, 'kappa': 4.0}, use_gpu=device)
    model.train(max_epochs=gansteps*3, plan_kwargs={'lr': lr, 'lr2': lr, 'kappa': 4.0}, use_gpu=device)

    if save:
        if path is None:
            # os.makedirs instead of shelling out to `mkdir`: portable and
            # safe for project paths containing spaces or shell characters.
            os.makedirs(name + '/models', exist_ok=True)
            path = name + '/models/fast_reconst1_' + sid
        torch.save(model.module.state_dict(), path)

    # The history file is written regardless of `save`; make sure its
    # directory exists so a standalone call does not fail here after training.
    os.makedirs(name + '/history', exist_ok=True)
    with open(name+'/history/pretrain1_' + sid + '.pkl', 'wb') as pickle_file:
        pickle.dump(model.history, pickle_file)

    return model
# reconst stage 2
def reconst_pretrain2(name:str, sid:str ,premodel:Union[str,fastgenerator],device='cuda:0',k=15,diagw=1.0,vaesteps=50,gansteps=50,lr=1e-4,save=True,path=None)->fastgenerator:
    """
    Accelerated version of pretrain 2 reconstruction.

    Builds a ``fastgenerator`` that additionally receives the sample's own
    mean profile as ``bulk`` (``countbulkweight=1``), optionally warm-starts
    it from ``premodel``, trains it in two consecutive ``train`` calls, and
    writes the history to ``<name>/history/pretrain2_<sid>.pkl``.

    Parameters
    ----------
    name
        The project name (used as the project directory path).
    sid
        Sample ID of the sample to be reconstructed.
    premodel
        Pretrained model, a path to saved pretrained weights, or ``None`` to
        start from freshly initialised weights.
    device
        CPU or GPU for model training.
    k
        Number of neighbors to consider in the cell graph (forwarded to
        ``setdata``).
    diagw
        The weight of the original cell when aggregating the information
        (forwarded to ``setdata``).
    vaesteps
        Epochs of the first training call (``lr2`` = 0).
    gansteps
        The second training call (``lr2`` = ``lr``) runs ``gansteps * 3``
        epochs.
    lr
        Learning rate.
    save
        Whether to save the model weights.
    path
        Path for saving the model. Defaults to
        ``<name>/models/fastreconst2_<sid>``.

    Returns
    -------
    model
        The trained model.
    """
    adata, adj, variances, bulk, geneset_len = setdata(name, sid, device, k, diagw)

    # Mean over cells of the feature matrix, used as the bulk target.
    bulk = (np.array(adata.X)).mean(axis=0)
    bulk = bulk.reshape((-1))
    bulk = torch.tensor(bulk).to(device)

    fastgenerator.setup_anndata(adata)
    model = fastgenerator(variances=variances, bulk=bulk, geneset_len=geneset_len, adata=adata,
                          n_hidden=256, n_latent=32, dropout_rate=0,
                          countbulkweight=1, logbulkweight=0, absbulkweight=0, abslogbulkweight=0,
                          power=2, corrbulkweight=0)

    # Warm start. The signature advertises both a model object and a path,
    # so both are supported (same convention as `unisemi0`).
    if premodel is None:
        pass
    elif isinstance(premodel, (str, os.PathLike)):
        model.module.load_state_dict(torch.load(premodel))
    else:
        model.module.load_state_dict(premodel.module.state_dict())

    model.train(max_epochs=vaesteps, plan_kwargs={'lr': lr, 'lr2': 0, 'kappa': 40.0}, use_gpu=device)
    model.train(max_epochs=gansteps*3, plan_kwargs={'lr': lr, 'lr2': lr, 'kappa': 40.0}, use_gpu=device)

    if save:
        if path is None:
            os.makedirs(name + '/models', exist_ok=True)
            path = name + '/models/fastreconst2_' + sid
        torch.save(model.module.state_dict(), path)

    # The history file is written regardless of `save`; ensure the directory
    # exists so the run does not fail after training has finished.
    os.makedirs(name + '/history', exist_ok=True)
    with open(name+'/history/pretrain2_' + sid + '.pkl', 'wb') as pickle_file:
        pickle.dump(model.history, pickle_file)

    return model
def unisemi0(name,adata,adj,variances,geneset_len,bulk,batch_size,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150):
    """
    Inference mini-stage 0.

    Builds a ``fastgenerator`` with ``countbulkweight=1*8`` and
    ``upperbound=99999``, initialises it from ``premodel``, trains it for
    ``epochs`` epochs and saves the weights to ``<name>/tmp/model0``.

    Parameters
    ----------
    name
        The project name (used as the project directory path).
    adata, variances, geneset_len, bulk
        Forwarded to the ``fastgenerator`` constructor.
    batch_size
        Batch size forwarded to ``train``.
    premodel
        Either a path to saved weights (``str`` or ``os.PathLike``) or a model
        object exposing ``.module.state_dict()``.
    device
        Device forwarded to ``train`` as ``use_gpu``.
    lr
        Learning rate.
    epochs
        Number of training epochs.
    adj, reprepid, tgtpid, k, diagw
        Accepted for signature parity with the other mini-stages; not used in
        this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.
    """
    model0 = fastgenerator(adata=adata, variances=variances, geneset_len=geneset_len,
                           bulk=bulk, n_hidden=256, n_latent=32,
                           dropout_rate=0, countbulkweight=1*8, logbulkweight=0, absbulkweight=0,
                           abslogbulkweight=0, corrbulkweight=0,
                           power=2, upperbound=99999)

    # isinstance (rather than an exact type comparison) so that str subclasses
    # and pathlib paths are treated as checkpoint paths instead of falling
    # through to the `.module` access below.
    if isinstance(premodel, (str, os.PathLike)):
        model0.module.load_state_dict(torch.load(premodel))
    else:
        model0.module.load_state_dict(premodel.module.state_dict())

    model0.train(max_epochs=epochs, plan_kwargs={'lr': lr, 'lr2': 1e-10, 'kappa': 4040*1e-10},
                 use_gpu=device, batch_size=batch_size)
    torch.save(model0.module.state_dict(), name+'/tmp/model0')
    return model0.history
def unisemi1(name,adata,adj,variances,geneset_len,bulk,batch_size,upperbound,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150):
    """
    Inference mini-stage 1.

    Builds a ``fastgenerator`` with ``countbulkweight=4*8`` and the given
    ``upperbound``, loads the weights written by stage 0
    (``<name>/tmp/model0``), trains for ``epochs`` epochs and saves the
    result to ``<name>/tmp/model1``.

    ``adj``, ``reprepid``, ``tgtpid``, ``premodel``, ``k`` and ``diagw`` are
    accepted but not used in this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.
    """
    stage_model = fastgenerator(
        adata=adata,
        variances=variances,
        geneset_len=geneset_len,
        bulk=bulk,
        n_hidden=256,
        n_latent=32,
        dropout_rate=0,
        countbulkweight=4*8,
        logbulkweight=0,
        absbulkweight=0,
        abslogbulkweight=0,
        corrbulkweight=0,
        power=2,
        upperbound=upperbound,
    )

    # Continue from the checkpoint left by the previous mini-stage.
    previous_weights = torch.load(name+'/tmp/model0')
    stage_model.module.load_state_dict(previous_weights)

    plan = {'lr':lr,'lr2':1e-10,'kappa':4040*1e-10}
    stage_model.train(max_epochs=epochs, plan_kwargs=plan, use_gpu=device, batch_size=batch_size)

    torch.save(stage_model.module.state_dict(), name+'/tmp/model1')
    return stage_model.history
def unisemi2(name,adata,adj,variances,geneset_len,bulk,batch_size,upperbound,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150):
    """
    Inference mini-stage 2.

    Builds a ``fastgenerator`` with ``countbulkweight=16*8`` and the given
    ``upperbound``, loads the weights written by stage 1
    (``<name>/tmp/model1``), trains for ``epochs`` epochs and saves the
    result to ``<name>/tmp/model2``.

    ``adj``, ``reprepid``, ``tgtpid``, ``premodel``, ``k`` and ``diagw`` are
    accepted but not used in this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.
    """
    stage_model = fastgenerator(
        adata=adata,
        variances=variances,
        geneset_len=geneset_len,
        bulk=bulk,
        n_hidden=256,
        n_latent=32,
        dropout_rate=0,
        countbulkweight=16*8,
        logbulkweight=0,
        absbulkweight=0,
        abslogbulkweight=0,
        corrbulkweight=0,
        power=2,
        upperbound=upperbound,
    )

    # Continue from the checkpoint left by the previous mini-stage.
    previous_weights = torch.load(name+'/tmp/model1')
    stage_model.module.load_state_dict(previous_weights)

    plan = {'lr':lr,'lr2':1e-10,'kappa':4040*1e-10}
    stage_model.train(max_epochs=epochs, plan_kwargs=plan, use_gpu=device, batch_size=batch_size)

    torch.save(stage_model.module.state_dict(), name+'/tmp/model2')
    return stage_model.history
def unisemi3(name,adata,adj,variances,geneset_len,bulk,batch_size,upperbound,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150):
    """
    Inference mini-stage 3.

    Builds a ``fastgenerator`` with ``countbulkweight=64*8`` and the given
    ``upperbound``, loads the weights written by stage 2
    (``<name>/tmp/model2``), trains for ``epochs`` epochs and saves the
    result to ``<name>/tmp/model3``.

    ``adj``, ``reprepid``, ``tgtpid``, ``premodel``, ``k`` and ``diagw`` are
    accepted but not used in this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.
    """
    stage_model = fastgenerator(
        adata=adata,
        variances=variances,
        geneset_len=geneset_len,
        bulk=bulk,
        n_hidden=256,
        n_latent=32,
        dropout_rate=0,
        countbulkweight=64*8,
        logbulkweight=0,
        absbulkweight=0,
        abslogbulkweight=0,
        corrbulkweight=0,
        power=2,
        upperbound=upperbound,
    )

    # Continue from the checkpoint left by the previous mini-stage.
    previous_weights = torch.load(name+'/tmp/model2')
    stage_model.module.load_state_dict(previous_weights)

    plan = {'lr':lr,'lr2':1e-10,'kappa':4040*1e-10}
    stage_model.train(max_epochs=epochs, plan_kwargs=plan, use_gpu=device, batch_size=batch_size)

    torch.save(stage_model.module.state_dict(), name+'/tmp/model3')
    return stage_model.history
def unisemi4(name,adata,adj,variances,geneset_len,bulk,batch_size,upperbound,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150):
    """
    Inference mini-stage 4.

    Builds a ``fastgenerator`` with ``countbulkweight=128*8`` and the given
    ``upperbound``, loads the weights written by stage 3
    (``<name>/tmp/model3``), trains for ``epochs`` epochs and saves the
    result to ``<name>/tmp/model4``.

    ``adj``, ``reprepid``, ``tgtpid``, ``premodel``, ``k`` and ``diagw`` are
    accepted but not used in this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.
    """
    stage_model = fastgenerator(
        adata=adata,
        variances=variances,
        geneset_len=geneset_len,
        bulk=bulk,
        n_hidden=256,
        n_latent=32,
        dropout_rate=0,
        countbulkweight=128*8,
        logbulkweight=0,
        absbulkweight=0,
        abslogbulkweight=0,
        corrbulkweight=0,
        power=2,
        upperbound=upperbound,
    )

    # Continue from the checkpoint left by the previous mini-stage.
    previous_weights = torch.load(name+'/tmp/model3')
    stage_model.module.load_state_dict(previous_weights)

    plan = {'lr':lr,'lr2':1e-10,'kappa':4040*1e-10}
    stage_model.train(max_epochs=epochs, plan_kwargs=plan, use_gpu=device, batch_size=batch_size)

    torch.save(stage_model.module.state_dict(), name+'/tmp/model4')
    return stage_model.history
def unisemi5(adata,adj,variances,geneset_len,bulk,batch_size,upperbound,reprepid,tgtpid,premodel,device='cuda:5',k=15,diagw=1.0,lr = 2e-4,epochs=150,name=None):
    """
    Inference mini-stage 5.

    Builds a ``fastgenerator`` with ``countbulkweight=512*8`` and the given
    ``upperbound``, loads the weights written by stage 4
    (``<name>/tmp/model4``), trains for ``epochs`` epochs and saves the
    result to ``<name>/tmp/model``.

    Parameters
    ----------
    adata, variances, geneset_len, bulk, upperbound
        Forwarded to the ``fastgenerator`` constructor.
    batch_size
        Batch size forwarded to ``train``.
    device
        Device forwarded to ``train`` as ``use_gpu``.
    lr
        Learning rate.
    epochs
        Number of training epochs.
    name
        The project name (directory holding ``tmp/``). Unlike the other
        mini-stages this is a trailing keyword parameter, so that the existing
        positional parameter order is left untouched.
    adj, reprepid, tgtpid, premodel, k, diagw
        Accepted but not used in this function body.

    Returns
    -------
    The ``history`` attribute of the trained model.

    Raises
    ------
    ValueError
        If ``name`` is not supplied.
    """
    # The checkpoint paths below need the project directory. It previously was
    # read from an undefined name, so every call failed with NameError.
    if name is None:
        raise ValueError("unisemi5 requires the project name: pass name=<project directory>")

    model = fastgenerator(adata=adata, variances=variances, geneset_len=geneset_len,
                          bulk=bulk, n_hidden=256, n_latent=32,
                          dropout_rate=0, countbulkweight=512*8, logbulkweight=0, absbulkweight=0,
                          abslogbulkweight=0, corrbulkweight=0,
                          power=2, upperbound=upperbound)
    model.module.load_state_dict(torch.load(name+'/tmp/model4'))
    model.train(max_epochs=epochs, plan_kwargs={'lr': lr, 'lr2': 1e-10, 'kappa': 4040*1e-10},
                use_gpu=device, batch_size=batch_size)
    torch.save(model.module.state_dict(), name+'/tmp/model')
    return model.history
def fast_semi(name:str,reprepid:int,tgtpid:int,premodel:Union[fastgenerator,str],device:str='cuda:0',k:int=15,diagw:float=1.0,bulktype='pseudobulk',pseudocount=0.1,lr = 2e-4,epochs=150,ministages = 5) -> Tuple[ dict, np.array, fastgenerator]:
    """
    Accelerated version of single-cell inference for the target sample.

    Starting from ``premodel``, runs the mini-stage chain ``unisemi0`` ..
    ``unisemi4`` (stages 0-2 always, stage 3 when ``ministages > 3``, stage 4
    when ``ministages > 4``), then loads ``<name>/tmp/model<ministages-1>``
    into a final model, samples one profile per cell and writes the result,
    the model weights and the training history under the project directory.

    Parameters
    ----------
    name
        The project name (used as the project directory path).
    reprepid
        Sample ID (number) of the representative.
    tgtpid
        Sample ID (number) of the target sample.
    premodel
        Pretrained model or path to pretrained model.
    device
        CPU or GPU for model training.
    k
        Number of neighbors to consider in the cell graph.
    diagw
        The weight of the original cell when aggregating the information.
    bulktype
        'real', or 'pseudobulk' (alias 'pseudo').
    pseudocount
        Pseudocount added to numerator and denominator of the target /
        representative bulk ratio when ``bulktype == 'real'``.
    lr
        Learning rate.
    epochs
        Epochs for each mini-stage in the inference process.
    ministages
        Selects how many mini-stages run and which checkpoint the final model
        is loaded from; must be between 1 and 5.

    Returns
    -------
    histdic
        A dictionary containing the training history information
        (``total<i>`` / ``bulk<i>`` per executed stage).
    xsemi
        Inferred single-cell data (gene columns only, values <= 10 zeroed).
    model
        Trained inference model.

    Raises
    ------
    ValueError
        If ``bulktype`` is not one of the supported values or ``ministages``
        is outside 1..5.
    """
    # Fail fast on bad arguments. Previously an unknown bulktype printed a
    # message and returned None (which callers then failed to unpack), and an
    # out-of-range ministages only failed after all training had finished,
    # when the requested checkpoint file turned out not to exist.
    if bulktype not in ('real', 'pseudobulk', 'pseudo'):
        raise ValueError('Error. Please specify bulktype as "pseudobulk" or "real".')
    if not 1 <= ministages <= 5:
        raise ValueError('ministages must be between 1 and 5, got ' + str(ministages))

    with open(name + '/sids.txt', 'r') as f:
        sids = [line.strip() for line in f]

    # Every output location used below; creating them here lets this function
    # be called directly, not only through tgtinfer / scinfer.
    for subdir in ('tmp', 'models', 'inferreddata', 'history'):
        os.makedirs(name + '/' + subdir, exist_ok=True)

    adata, adj, variances, _, geneset_len = setdata(name, sids[reprepid], device=device, k=k, diagw=diagw)

    # Progressively tighter bounds handed to mini-stages 1..4.
    maxexpr = adata.X.max()
    upperbounds = [maxexpr/2, maxexpr/4, maxexpr/8, maxexpr/(8*np.sqrt(2))]

    # Number of gene columns; used to cut appended gene set columns off the output.
    genelen = len(np.load(name+'/hvgenes.npy', allow_pickle=True))

    # --- target bulk -------------------------------------------------------
    bulkdata = anndata.read_h5ad(name + '/processed_bulkdata.h5ad')
    # exp(x) - 1 undoes a log1p transform -- presumably how the processed bulk
    # data are stored; confirm against the preprocessing step.
    tgtbulk = np.array(np.exp(bulkdata.X[tgtpid]) - 1).reshape((1, -1))
    if bulktype == 'real':
        repbulk = np.array(np.exp(bulkdata.X[reprepid]) - 1).reshape((1, -1))
        pseudobulk = (np.array(adata.X)).mean(axis=0).reshape((1, -1))
        # Scale the representative's pseudobulk by the target/representative
        # real-bulk ratio; columns beyond the bulk genes get ratio 1.
        ratio = np.array((tgtbulk+pseudocount)/(repbulk+pseudocount))
        ratio = np.concatenate([ratio, np.ones((1, pseudobulk.shape[1]-ratio.shape[1]))], axis=1)
        bulk = pseudobulk * ratio
    else:
        # Use the target bulk directly for the gene columns and keep the
        # representative's own means for the remaining columns.
        bulk = np.array(adata.X.mean(axis=0)).reshape((1, -1))
        bulk[:, :tgtbulk.shape[1]] = tgtbulk
    bulk = torch.tensor(bulk).to(device)

    batch_size = int(min(adata.X.shape[0], 9000))

    # --- mini-stage chain --------------------------------------------------
    fastgenerator.setup_anndata(adata)
    histdic = {}

    hist = unisemi0(name, adata, adj, variances, geneset_len, bulk, batch_size, reprepid, tgtpid,
                    premodel, device=device, k=k, diagw=1.0, lr=lr, epochs=epochs)
    histdic['total0'] = hist['train_loss_epoch']
    histdic['bulk0'] = hist['kl_global_train']
    gc.collect()
    torch.cuda.empty_cache()

    later_stages = [unisemi1, unisemi2]
    if ministages > 3:
        later_stages.append(unisemi3)
    if ministages > 4:
        later_stages.append(unisemi4)

    # Each stage reads the previous stage's checkpoint from <name>/tmp and
    # writes its own, so only the history has to be carried here.
    for stage, (run_stage, bound) in enumerate(zip(later_stages, upperbounds), start=1):
        hist = run_stage(name, adata, adj, variances, geneset_len, bulk, batch_size, bound,
                         reprepid, tgtpid, premodel, device=device, k=k, diagw=1.0, lr=lr, epochs=epochs)
        histdic['total' + str(stage)] = hist['train_loss_epoch']
        histdic['bulk' + str(stage)] = hist['kl_global_train']
        gc.collect()
        torch.cuda.empty_cache()

    # --- final model and inference ----------------------------------------
    model = fastgenerator(adata=adata, variances=variances, geneset_len=geneset_len,
                          bulk=bulk, n_hidden=256, n_latent=32,
                          dropout_rate=0, countbulkweight=512, logbulkweight=0, absbulkweight=0,
                          abslogbulkweight=0, corrbulkweight=0,
                          power=2, upperbound=upperbounds[3])
    model.module.load_state_dict(torch.load(name+'/tmp/model'+str(ministages-1)))

    scdl = model._make_data_loader(adata=adata, batch_size=batch_size)
    xsemi = [model.module.sample(tensors, n_samples=1) for tensors in scdl]
    xsemi = np.array(torch.cat(xsemi))[:, :genelen]

    pair = sids[reprepid] + '_to_' + sids[tgtpid]
    torch.save(model.module.state_dict(), name+'/models/semi_'+sids[reprepid]+"_to_"+sids[tgtpid])

    # Zero out every value that is not above 10.
    xsemi = xsemi*(xsemi > 10)
    np.save(name + '/inferreddata/' + pair, xsemi)

    with open(name+'/history/inference_' + pair + '.pkl', 'wb') as pickle_file:
        pickle.dump(histdic, pickle_file)

    gc.collect()
    torch.cuda.empty_cache()
    return histdic, xsemi, model
def tgtinfer(name:str, representative:Union[str,int],target:Union[str,int],bulktype:str='pseudobulk',
             lambdad:float = 4.0,
             pretrain1batch:int = 128,
             pretrain1lr:float = 1e-3,
             pretrain1vae:int = 100,
             pretrain1gan:int = 100,
             lambdabulkr:float = 1,
             pretrain2lr:float = 1e-4,
             pretrain2vae:int = 50,
             pretrain2gan:int = 50,
             inferepochs:int = 150,
             lambdabulkt:float = 8.0,
             inferlr:float = 2e-4,
             pseudocount:float = 0.1,
             k:int = 15,
             device:str = 'cuda:0') -> None:
    """
    Computationally infer the single-cell data of a single non-representative target sample based on a representative's single-cell data and bulk data of both samples.

    Pretrain 1 and pretrain 2 models for the representative are loaded from
    ``<name>/models`` when present and trained otherwise. If
    ``<name>/inferreddata/<representative>_to_<target>.npy`` already exists,
    inference is skipped.

    Parameters
    ----------
    name
        The project name.
    representative
        The representative. Either indicated using sample ID (str) or the i-th (int) sample.
    target
        The target sample. Either indicated using sample ID (str) or the i-th (int) sample.
    bulktype
        Pseudobulk or real bulk data
    lambdad
        Scaling factor for the discriminator loss. Accepted but not forwarded
        by this function.
    pretrain1batch
        The mini-batch size during the first pretrain stage. Accepted but not
        forwarded by this function.
    pretrain1lr
        The learning rate used in the first pretrain stage.
    pretrain1vae
        The number of epochs for training the VAE during the first pretrain stage.
    pretrain1gan
        The number of iterations for training GAN during the first pretrain stage.
    lambdabulkr
        Scaling factor for representative bulk loss for pretrain 2. Accepted
        but not forwarded by this function.
    pretrain2lr
        Pretrain 2 learning rate.
    pretrain2vae
        The number of epochs for training the VAE during the second pretrain stage.
    pretrain2gan
        The number of iterations for training the GAN during the second pretrain stage.
    inferepochs
        The number of epochs used for each mini-stage during inference.
    lambdabulkt
        Scaling factor for the initial target bulk loss. Accepted but not
        forwarded by this function.
    inferlr
        Infer stage learning rate.
    pseudocount
        Pseudocount used when converting real bulk to pseudobulk space
    k
        The number of nearest neighbors used in cell graph.
    device
        Which device to use, e.g. 'cpu', 'cuda:0'.

    Returns
    -------
    None

    Example
    -------
    >>> name = 'project_name'
    >>> scSemiProfiler.tgtinfer(name = name, representative = 6, target = 7, bulktype = 'real')
    """
    # os.makedirs instead of shelling out to `mkdir`: portable and safe for
    # project paths containing spaces or shell metacharacters.
    for subdir in ('inferreddata', 'models', 'tmp', 'history'):
        os.makedirs(name + '/' + subdir, exist_ok=True)

    diagw = 1.0

    with open(name + '/sids.txt', 'r') as f:
        sids = [line.strip() for line in f]

    # Integers are positions in sids.txt; anything else is taken as a sample ID.
    rp = sids[representative] if isinstance(representative, int) else representative
    tgt = sids[target] if isinstance(target, int) else target
    tgtpid = sids.index(tgt)
    reprepid = sids.index(rp)

    print('pretrain 1: representative reconstruction')
    path = name + '/models/fast_reconst1_' + rp
    if os.path.isfile(path):
        print('load existing pretrain 1 reconstruction model for ' + rp)
        adata, adj, variances, bulk, geneset_len = setdata(name, rp, device, k, diagw)
        repremodel = fastgenerator(variances, None, geneset_len, adata,
                                   n_hidden=256, n_latent=32, dropout_rate=0)
        repremodel.module.load_state_dict(torch.load(path))
    else:
        # The documented learning rate is forwarded here; it used to be
        # silently ignored in favour of fastrecon's own default.
        repremodel = fastrecon(name=name, sid=rp, device=device, k=k, diagw=1,
                               vaesteps=int(pretrain1vae), gansteps=int(pretrain1gan),
                               lr=pretrain1lr, save=True, path=None)

    print('pretrain2: reconstruction with representative bulk loss')
    path = name + '/models/fastreconst2_' + rp
    if os.path.isfile(path):
        print('load existing pretrain 2 model for ' + rp)
        adata, adj, variances, bulk, geneset_len = setdata(name, rp, device, k, diagw)
        repremodels2 = fastgenerator(variances, None, geneset_len, adata,
                                     n_hidden=256, n_latent=32, dropout_rate=0)
        repremodels2.module.load_state_dict(torch.load(path))
    else:
        repremodels2 = reconst_pretrain2(name, rp, repremodel, device, k=k, diagw=1.0,
                                         vaesteps=int(pretrain2vae), gansteps=int(pretrain2gan),
                                         lr=pretrain2lr, save=True)

    fname = rp + '_to_' + tgt + '.npy'
    if os.path.isfile(name + '/inferreddata/' + fname):
        # Actually skip, as the message says (and as scinfer does); the
        # inference used to be rerun regardless.
        print('Inference for '+tgt+' has been finished previously. Skip.')
        return

    fast_semi(name, reprepid, tgtpid, repremodels2, device=device, k=k, diagw=1.0,
              bulktype=bulktype, lr=inferlr, epochs=inferepochs, pseudocount=pseudocount)
    print('Finished target sample single-cell inference')
    return
def scinfer(name:str, representatives:str,cluster:str,bulktype:str='pseudobulk',
            lambdad:float = 4.0,
            pretrain1batch:int = 128,
            pretrain1lr:float = 1e-3,
            pretrain1vae:int = 100,
            pretrain1gan:int = 100,
            lambdabulkr:float = 1,
            pretrain2lr:float = 1e-4,
            pretrain2vae:int = 50,
            pretrain2gan:int = 50,
            inferepochs:int = 150,
            lambdabulkt:float = 8.0,
            inferlr:float = 2e-4,
            pseudocount:float = 0.1,
            ministages:int = 5,
            k:int = 15,
            device:str = 'cuda:0') -> None:
    """
    Computationally infer the single-cell data of all non-representative samples (target samples) based on the cohort's bulk data and the representatives' single-cell data

    For every representative, pretrain 1 and pretrain 2 models are loaded from
    ``<name>/models`` when present and trained otherwise. Every other sample
    is then inferred from the representative of its cluster, unless its output
    file already exists in ``<name>/inferreddata``. The wall-clock time of
    pretrain 1 is written to ``pretrain1time.txt`` in the current working
    directory.

    Parameters
    ----------
    name
        The project name.
    representatives
        Path to a "txt" file containing the representative sample IDs (number)
    cluster
        Path to a "txt" file containing the cluster label information
    bulktype
        Pseudobulk or real bulk data
    lambdad
        Scaling factor for the discriminator loss. Accepted but not forwarded
        by this function.
    pretrain1batch
        The mini-batch size during the first pretrain stage. Accepted but not
        forwarded by this function.
    pretrain1lr
        The learning rate used in the first pretrain stage.
    pretrain1vae
        The number of epochs for training the VAE during the first pretrain stage.
    pretrain1gan
        The number of iterations for training GAN during the first pretrain stage.
    lambdabulkr
        Scaling factor for representative bulk loss for pretrain 2. Accepted
        but not forwarded by this function.
    pretrain2lr
        Pretrain 2 learning rate.
    pretrain2vae
        The number of epochs for training the VAE during the second pretrain stage.
    pretrain2gan
        The number of iterations for training the GAN during the second pretrain stage.
    inferepochs
        The number of epochs used for each mini-stage during inference.
    lambdabulkt
        Scaling factor for the initial target bulk loss. Accepted but not
        forwarded by this function.
    inferlr
        Infer stage learning rate.
    pseudocount
        Pseudocount used when converting data from real bulk space to pseudobulk space
    ministages
        Number of ministages during inference
    k
        The number of nearest neighbors used in cell graph.
    device
        Which device to use, e.g. 'cpu', 'cuda:0'.

    Returns
    -------
    None

    Example
    -------
    >>> name = 'project_name'
    >>> representatives = name + '/status/init_representatives.txt'
    >>> cluster = name + '/status/init_cluster_labels.txt'
    >>> scSemiProfiler.scinfer(name = name, representatives = representatives, cluster = cluster, bulktype = 'pseudobulk')
    """
    # os.makedirs instead of shelling out to `mkdir`: portable and safe for
    # project paths containing spaces or shell metacharacters.
    for subdir in ('inferreddata', 'models', 'tmp', 'history'):
        os.makedirs(name + '/' + subdir, exist_ok=True)

    diagw = 1.0
    print('Start single-cell inference in cohort mode')

    with open(name + '/sids.txt', 'r') as f:
        sids = [line.strip() for line in f]
    # Blank lines (e.g. a trailing newline) are ignored in the integer files;
    # they used to abort the run with a ValueError from int('').
    with open(representatives, 'r') as f:
        repres = [int(line.strip()) for line in f if line.strip()]
    with open(cluster, 'r') as f:
        cluster_labels = [int(line.strip()) for line in f if line.strip()]

    # ---- pretrain 1 -------------------------------------------------------
    pretrain1start = timeit.default_timer()
    print('pretrain 1: representative reconstruction')
    repremodels = []
    for rp in repres:
        sid = sids[rp]
        path = name + '/models/fast_reconst1_' + sid
        if os.path.isfile(path):
            print('load existing pretrain 1 reconstruction model for '+sid)
            adata, adj, variances, bulk, geneset_len = setdata(name, sid, device, k, diagw)
            model = fastgenerator(variances, None, geneset_len, adata,
                                  n_hidden=256, n_latent=32, dropout_rate=0)
            model.module.load_state_dict(torch.load(path))
        else:
            # The documented learning rate is forwarded here; it used to be
            # silently ignored in favour of fastrecon's own default.
            model = fastrecon(name=name, sid=sid, device=device, k=k, diagw=1,
                              vaesteps=int(pretrain1vae), gansteps=int(pretrain1gan),
                              lr=pretrain1lr, save=True, path=None)
        repremodels.append(model)

    pretrain1end = timeit.default_timer()
    with open('pretrain1time.txt', 'w') as f:
        f.write(str(pretrain1end-pretrain1start))

    # ---- pretrain 2 -------------------------------------------------------
    print('pretrain2: reconstruction with representative bulk loss')
    repremodels2 = []
    for rp, stage1_model in zip(repres, repremodels):
        sid = sids[rp]
        path = name + '/models/fastreconst2_' + sid
        if os.path.isfile(path):
            print('load existing pretrain 2 model for ' + sid)
            adata, adj, variances, bulk, geneset_len = setdata(name, sid, device, k, diagw)
            model = fastgenerator(variances, None, geneset_len, adata,
                                  n_hidden=256, n_latent=32, dropout_rate=0)
            model.module.load_state_dict(torch.load(path))
        else:
            model = reconst_pretrain2(name, sid, stage1_model, device, k=k, diagw=1.0,
                                      vaesteps=int(pretrain2vae), gansteps=int(pretrain2gan),
                                      lr=pretrain2lr, save=True)
        repremodels2.append(model)

    # ---- inference --------------------------------------------------------
    print('inference')
    for tgtpid in range(len(sids)):
        if tgtpid in repres:
            continue
        # cluster_labels[i] is an index into the representatives list.
        label = cluster_labels[tgtpid]
        reprepid = repres[label]
        fname = sids[reprepid]+'_to_'+sids[tgtpid]+'.npy'
        if os.path.isfile(name + '/inferreddata/' + fname):
            print('Inference for '+sids[tgtpid]+' has been finished previously. Skip.')
            continue
        fast_semi(name, reprepid, tgtpid, repremodels2[label], device=device, k=k, diagw=1.0,
                  bulktype=bulktype, lr=inferlr, epochs=inferepochs,
                  pseudocount=pseudocount, ministages=ministages)

    print('Finished single-cell inference')
    return
def main():
    """
    Command-line entry point for single-cell inference.

    Runs ``tgtinfer`` (single-sample mode) when ``--targetid`` is given, and
    ``scinfer`` (cohort mode) otherwise; cohort mode requires ``--cluster``.
    All options are forwarded by keyword.
    """
    parser = argparse.ArgumentParser(description="scSemiProfiler scinfer")
    # Drop argparse's default groups so that only the two groups below are
    # listed in --help (relies on a private attribute of ArgumentParser).
    parser._action_groups.pop()
    required = parser.add_argument_group('required arguments')
    optional = parser.add_argument_group('optional arguments')

    required.add_argument('--name', required=True, help="Project name (same as previous steps).")
    required.add_argument('--representatives', required=True, help="Either a txt file including all the IDs of the representatives used in the current round of semi-profiling when running in cohort mode, or a single sample ID when running in single-sample mode.")

    optional.add_argument('--cluster', required=False, default='na', help="A txt file specifying the cluster membership. Required when running in cohort mode.")
    optional.add_argument('--targetid', required=False, default='na', help="Sample ID of the target sample when running in single-sample mode.")
    optional.add_argument('--bulktype', required=False, default='real', help="Specify 'pseudo' for pseudobulk or 'real' for real bulk data. (Default: real)")
    optional.add_argument('--lambdad', required=False, default='4.0', help="Scaling factor for the discriminator loss for training the VAE generator. (Default: 4.0)")
    optional.add_argument('--pretrain1batch', required=False, default='128', help="Sample Batch Size of the first pretrain stage. (Default: 128)")
    optional.add_argument('--pretrain1lr', required=False, default='1e-3', help="Learning rate of the first pretrain stage. (Default: 1e-3)")
    optional.add_argument('--pretrain1vae', required=False, default='100', help="The number of epochs for training the VAE generator during the first pretrain stage. (Default: 100)")
    optional.add_argument('--pretrain1gan', required=False, default='100', help="The number of iterations for training the generator and discriminator jointly during the first pretrain stage. (Default: 100)")
    optional.add_argument('--lambdabulkr', required=False, default='1.0', help="Scaling factor for the representative bulk loss. (Default: 1.0)")
    # The help texts of the next two options were copy-paste errors.
    optional.add_argument('--pretrain2lr', required=False, default='1e-4', help="Learning rate of the second pretrain stage. (Default: 1e-4)")
    optional.add_argument('--pretrain2vae', required=False, default='50', help="The number of epochs for training the VAE generator during the second pretrain stage. (Default: 50)")
    optional.add_argument('--pretrain2gan', required=False, default='50', help="The number of iterations for training the generator and discriminator jointly during the second pretrain stage. (Default: 50)")
    optional.add_argument('--inferepochs', required=False, default='150', help="The number of epochs for training the generator in each mini-stage during the inference. (Default: 150)")
    optional.add_argument('--lambdabulkt', required=False, default='8.0', help="Scaling factor for the initial target bulk loss. (Default: 8.0)")
    optional.add_argument('--inferlr', required=False, default='2e-4', help="Learning rate during the inference stage. (Default: 2e-4)")

    args = parser.parse_args()

    # Hyper-parameters shared by both modes, passed by keyword so that a
    # change in parameter order can never shift values into the wrong slot
    # (the old positional call inserted `targetid` where `bulktype` belongs,
    # shifting every later argument by one).
    shared = dict(
        bulktype=args.bulktype,
        lambdad=float(args.lambdad),
        pretrain1batch=int(args.pretrain1batch),
        pretrain1lr=float(args.pretrain1lr),
        pretrain1vae=int(args.pretrain1vae),
        pretrain1gan=int(args.pretrain1gan),
        lambdabulkr=float(args.lambdabulkr),
        pretrain2lr=float(args.pretrain2lr),
        pretrain2vae=int(args.pretrain2vae),
        pretrain2gan=int(args.pretrain2gan),
        inferepochs=int(args.inferepochs),
        lambdabulkt=float(args.lambdabulkt),
        inferlr=float(args.inferlr),
    )

    if args.targetid != 'na':
        # Single-sample mode: --representatives is a single sample ID.
        tgtinfer(name=args.name, representative=args.representatives, target=args.targetid, **shared)
    else:
        if args.cluster == 'na':
            parser.error("--cluster is required in cohort mode (or pass --targetid for single-sample mode)")
        scinfer(name=args.name, representatives=args.representatives, cluster=args.cluster, **shared)


if __name__ == "__main__":
    main()