Commit f68ed031 authored by Spitzer's avatar Spitzer
Browse files

simulation code

parent dd40725b
clear all;
addpath('tools');
figure('units','normalized','outerposition',[0.2 0.1 0.6 0.85],'Color',[1 1 1]);
colormap(viridis());
nrows=4;
ncols=5;
fsize=13;
markersize=65;
lwidth=3.5;
a1=0.05; % default learning rate (alpha+)
as=0:0.001:0.1; % learning rates (alpha+/alpha-) in performance map
supp=0; % 1: also show supp Fig S1
titles={'Full feedback'; 'Partial feedback'};
mstr={'Q1','Q2';'Q1*','Q2*'};
for ex=1:2
if ex==1 % full feedback
det=1; % deterministic FB
dataf=['..' filesep '..' filesep 'data' filesep 'combined_mat' filesep 'exp1.mat'];
tauI=0.2;
eta=0;
selsub=1; % any sub
lims1=[-1 1];
lims2=[-0.1 1];
elseif ex==2 % partial feedback
det=0; % exp 4 has deterministic FB
dataf=['..' filesep '..' filesep 'data' filesep 'combined_mat' filesep 'exp4.mat'];
tauI=0.04;
eta=8;
selsub=2;
lims1=[-0.3 0.3];
lims2=[-0.1 0.6];
end
% load experiment data (same as in data analysis)
load(dataf);
[datamat, nsubs] = rmloperf(behav_result_mat,0.60);
% allocate
plotmaps=NaN(length(as),length(as),nsubs);
modRDM1=[]; modRDM2=[]; modRDM3=[]; modRDM4=[]; modRDM5=[];
VI1collect=cell(1,nsubs);
VI2collect=VI1collect;
VI3collect=VI1collect;
%% simulate
parfor n=1:nsubs
disp(['subject ' num2str(n) ' -started']);
[trials,FBtru,grT,nstims] = preptrials(datamat(:,:,:,n),det,0);
% switch off pair-level -> Q models only
gam=0; lam=100; tauP=100;
% symmetric (Q1)
[CP,Vitem]=modfunQP(a1,a1,eta,tauI,gam,lam,tauP,trials,FBtru);
VI1collect(n)={Vitem}; % store item value trajectories
modRDM1(:,:,n)=cpmatrix(trials,CP,1:nstims);
% alpha+/alpha- performance map
plotmaps(:,:,n) = simperf(as,eta,tauI,gam,lam,tauP,trials,FBtru,grT,0);
% asymmetric (Q2)
[CP,Vitem]=modfunQP(a1,0,eta,tauI,gam,lam,tauP,trials,FBtru);
VI2collect(n)={Vitem}; % store item value trajectories
modRDM2(:,:,n)=cpmatrix(trials,CP,1:nstims);
if ex==2 % for suppl. Fig S1
% symmetric model Q1 - partial fb
[CP,Vitem]=modfunQP(a1,a1,0,0.2,gam,lam,tauP,trials,FBtru);
VI3collect(n)={Vitem}; % store item value trajectories
modRDM3(:,:,n)=cpmatrix(trials,CP,1:nstims);
gam=1; tauP=0.1 % model P - partial fb
CP=modfunQP(0,0,0,100,gam,lam,tauP,trials,FBtru);
modRDM4(:,:,n)=cpmatrix(trials,CP,1:nstims);
lam=1; % enable 'linking' pairs recall (model Pi) - partial fb
CP=modfunQP(0,0,0,100,gam,lam,tauP,trials,FBtru);
modRDM5(:,:,n)=cpmatrix(trials,CP,1:nstims);
end
end
%% plot Figure 2
figex=ncols*2*(ex-1);
mvup=0.04*(ex-1);
plotQcp(VI1collect{selsub},modRDM1,nrows,ncols,figex+1,mvup,lims1,lwidth,fsize,mstr{ex,1});
h=plotmap(plotmaps,as,a1,nrows,ncols,figex+2,mvup,markersize,fsize,0);
h.Position(1)=h.Position(1)+0.055;
h.Position(2)=h.Position(2)-0.14+mvup;
h.Position(3)=h.Position(3)+0.01;
h.Position(4)=h.Position(4)+0.01;
t=title(titles{ex},'Fontsize',18,'Units','normalized');
t.Position(1)=t.Position(1)-0.3;
t.Position(2)=t.Position(2)+0.5;
plotQcp(VI2collect{selsub},modRDM2,nrows,ncols,figex+3.5,mvup,lims2,lwidth,fsize,mstr{ex,2});
end
%% supp Fig S1
if supp
figure('units','normalized','outerposition',[0.2 0.1 0.6 0.85],'Color',[1 1 1]);
colormap(viridis())
% Qvals and choice mat model Q1 (partial feedback)
plotQcp(VI3collect{selsub},modRDM3,nrows,ncols,1,mvup,[-1 1],lwidth,fsize,'Q1');
% choice mat model P (partial feedback)
h=subplot(nrows,ncols,2.5); % plot parameters
plotcp(modRDM4,fsize);
t=text(3.5,1,'P','Fontsize',fsize+2,'FontWeight','bold','Color',repmat(0.4,1,3));
h.Position(2)=h.Position(2)-0.08;
% choice mat model Pi (partial feedback)
h=subplot(nrows,ncols,3.5); % plot parameters
plotcp(modRDM5,fsize);
t=text(3.5,1,'Pi','Fontsize',fsize+2,'FontWeight','bold','Color',repmat(0.4,1,3));
h.Position(2)=h.Position(2)-0.08;
end
function Qmovie(startdelay) % in seconds
dataf=['..' filesep 'Data' filesep 'behavdat_TI_prob_EEGn35.mat'];
a1=0.05;
tauI=0.2;
eta=8;
%% get trials
load(dataf);
[datamat,nsubs] = rmloperf(behav_result_mat,0.60);
[trials, FBtru] = preptrials(datamat(:,:,:,1),1,0);
%% simulate
[~,Vitem]=modfunQP(a1,a1, 0 ,tauI,0,100,100,trials,FBtru);
Q(:,:,1)=Vitem; % model Q1
[~,Vitem]=modfunQP(a1,a1,eta,tauI,0,100,100,trials,FBtru);
Q(:,:,2)=Vitem; % model Q1*
[~,Vitem]=modfunQP(a1, 0,eta,tauI,0,100,100,trials,FBtru);
Q(:,:,3)=Vitem; % model Q2*
%% play movie
simtrials=200;
flip=1; % order items A-H with best item (A) first
playQmovie(Q,trials,FBtru,simtrials,flip,startdelay);
end
\ No newline at end of file
function [prefmat]=cpmatrix(trials,choices,val)
%% Choice matrix
prefmatA=NaN(length(val),length(val));
prefmatB=NaN(length(val),length(val));
for i=1:length(val)
for j=1:length(val)
index=find(trials(:,1)==i & trials(:,2)==j);
if i<j
prefmatA(j,i)=mean(choices(index));
elseif i>j
prefmatB(i,j)=1-mean(choices(index));
end
end
end
prefmat=(prefmatA+prefmatB)./2;
prefmat=fliplr(prefmat);
prefmat=flipud(prefmat);
end
\ No newline at end of file
function [CP,Vitem]=modfunQP(a1,a2,eta,tauI,gam,lam,tauP,trials,FBtru)
%% initialize
Vi=zeros(1,8); % flat item values for Q-learning
U=ones(1,7); % flat beta distributions for pair learning
L=U;
ntrials=length(trials);
nstims=length(unique(trials));
% allocate trial-by-trial matrices
Vitem=NaN(nstims,ntrials); % item values (8 x ntrials)
for i=1:length(trials)
%% Let the model make a choice
Vitem(:,i)=Vi; % item values at trial i
PairP(:,i)=U./(U+L); % pair probs at trial i
curTri=trials(i,:); % get current trial stimulus pair
v1=min(curTri); % sort so that v1 is always the lower-valued item
v2=max(curTri);
vdist=v2-v1;
% CP based on learned item values with noise tau_item (CPitem)
CPi=sigmoid(Vi(v2)-Vi(v1),tauI,0);
% CP based on learned pair-relations
tmpI=v1:v2-1; % indices of relations linking the pair
relPs=(U(tmpI))./(U(tmpI) + L(tmpI)); % prob. pair relations linking the pair
relP=sum(relPs-0.5)./vdist^(lam+1)+0.5; % recall of linking pairs as per lambda
CPp=sigmoid(relP-0.5,tauP,0); % CP with noise tau_pair (CPpair)
% whichever is more diagnostic (CPitem or CPpair) will guide choice
if abs(CPi-0.5)>abs(CPp-0.5)
transP=CPi;
elseif abs(CPi-0.5)<abs(CPp-0.5)
transP=CPp;
else % before any learning occurred, just guess
transP=0.5;
end
% flip back to match how items were presented in experiment
if diff(curTri)>0
CP(i)=transP;
else
CP(i)=1-transP;
end
%% Learn / update
if ~isnan(FBtru(i))
% item-level(Q-)learning)
vdiff=Vi(v2)-Vi(v1); % item-level value difference
vdiff=vdiff*eta;
if FBtru(i)==1
Vi(v1)=Vi(v1)+a2*(-1+vdiff-Vi(v1));
Vi(v2)=Vi(v2)+a1*( 1-vdiff-Vi(v2));
elseif FBtru(i)==0
Vi(v1)=Vi(v1)+a1*( 1+vdiff-Vi(v1));
Vi(v2)=Vi(v2)+a2*(-1-vdiff-Vi(v2));
end
% update pair-relational memory
if FBtru(i)==1
U(min(curTri))=U(min(curTri))+gam;
else
L(min(curTri))=L(min(curTri))+gam;
end
end
end
end
function playQmovie(V,trials,FBtru,simtrials,flip,startdelay)
h=figure('units','normalized','outerposition',[0.3 0.2 0.6 0.9],'Color',[1 1 1]);
h.OuterPosition(3:4)=h.OuterPosition(3:4)*0.68; % on 13" MB
titles={'model Q1'; 'model Q1*'; 'model Q2*'};
wincol=[84 130 53]./256;
loscol=[182 0 0]./256;
ftsize=12.5;
if flip
V=V(fliplr(1:size(V,1)),:,:);
end
nmodels=size(V,3);
%% initialize
for run=1:2
ps=(run-1).*3; % shift to second line of panels on second run
for j=1:nmodels
subplot(2,3,j+ps)
act=V(:,1,j);
% text placeholders
wintx{j+ps}=text(0.5,-0.40, '','Color',wincol,'FontSize',ftsize,'HorizontalAlignment','left'); hold on
lostx{j+ps}=text(0.5,-0.55, '','Color',loscol,'FontSize',ftsize,'HorizontalAlignment','left');
reptx{j+ps}=text(4.5,0.50, '','Color',[0.5 0.5 0.5],'FontSize',ftsize,'HorizontalAlignment','center');
% white disks in background
fb{j+ps}=scatter(1:8 ,act, 200, ones(8,3), 'filled'); hold on
% line plot
pl{j+ps}=plot(1:8,act,'-k');
set(gca,'FontSize',ftsize-2)
% items on top
dotcolors=plasma(8);
if flip
dotcolors=flipud(dotcolors);
end
sc{j+ps}=scatter([1:8], act, 65, dotcolors, 'filled');
stims=trials(j,:);
tx{j+ps}=text(0.5,0.8,'trial 0','FontSize',ftsize);
xticks(1:8);
labels=itemticks(1);
if ~flip
labels=labels(fliplr(1:8));
end
xticklabels(labels);
a = get(gca,'XTickLabel');
set(gca,'XTickLabel',a,'fontsize',ftsize-1)
xlim([0.1 8.9]);
%xlabel('item')
ylim([-0.7 0.9]);
yticks([-0.5 0 0.5]);
if j==1
ylabel('value(Q)')
end
if run==1
title(titles{j},'Fontsize',ftsize);
end
end
%% before sim starts
if run==1
pause(startdelay);
end
if run==2
wintx{1+ps}.String='\bf{winner}';
lostx{1+ps}.String='\bf{loser}';
pause(1)
for j=1:nmodels
reptx{j+ps}.String='\bf{slow motion}';
end
pause(2);
for j=1:nmodels
reptx{j+ps}.String='';
end
pause(1);
%simtrials=100;
end
% for j=1:nmodels
% tx{j+ps}.String=['trial ' num2str(1)];
% end
for i=1:simtrials
updcol=ones(8,3);
if ~isnan(FBtru(i)) & run==2
if FBtru(i)==1
indW=max(trials(i,:)); % winning item
indL=min(trials(i,:)); % losing item
elseif FBtru(i)==0
indW=min(trials(i,:)); % winning item
indL=max(trials(i,:)); % losing item
end
if flip
indW=-indW+9;
indL=-indL+9;
end
updcol(indW,:)=wincol;
updcol(indL,:)=loscol;
end
% highlight the to-be-updated items
for j=1:nmodels
tx{j+ps}.String=['trial ' num2str(i)];
act=V(:,i,j);
fb{j+ps}.YData=act;
fb{j+ps}.CData=updcol;
end
if ~isnan(FBtru(i)) & run==2
pause(0.2);
end
% update
for j=1:nmodels
act=V(:,i+1,j);
fb{j+ps}.YData=act;
pl{j+ps}.YData=act;
sc{j+ps}.YData=act;
end
if ~isnan(FBtru(i))
pause(0.1);
if run==2
pause(0.4);
end
end
if run==2
pause(0.1);
end
end
pause(1)
end
disp('fin');
end
function plotQcp(Qvals,choicemat,nrows,ncols,panel,mvup,lims,lwidth,fsize,mstr);
% Q value trajectories
h=subplot(nrows,ncols,panel);
set(0,'DefaultAxesColorOrder',plasma(8))';
plot(Qvals','Linewidth',lwidth);
set(gca,'FontSize',fsize)
ylim(lims);
set(gca,'box','off')
set(gca,'color','none')
h.Position(3)=h.Position(3)-0.01;
h.Position(4)=h.Position(4)-0.04;
%h.Position(1)=h.Position(1)-0.005;
h.Position(2)=h.Position(2)-0.03+mvup;
xlabel('trial')
ylabel('Value (Q)');
t=text(0.05,0.9,mstr,'Units','normalized','FontSize',fsize+2);
t.FontWeight='bold';
t.Color=repmat(0.4,1,3); % grey
if strmatch('Q2',mstr)
t.Color='r';
end
% choice mat
h=subplot(nrows,ncols,ncols+panel);
plotcp(choicemat,fsize)
h.Position(1)=h.Position(1)+0.002;
h.Position(2)=h.Position(2)+mvup;
end
function plotcp(choicemat,fsize)
mat=mean(choicemat,3);
pic=imagesc(mat');
set(gca,'visible','off')
pic.AlphaData=alphaout(8,0);
labels=itemticks(0);
for i=1:length(labels)-1
text(i-0.55,9,labels(i),'Fontsize',fsize+1);
text(-0.1,i+1,labels(i+1),'Fontsize',fsize+1,'HorizontalAlignment','center');
end
end
function h=plotmap(maps,as,a1,nrows,ncols,panel,mvup,markersize,fsize,colbar)
h=subplot(nrows,ncols,panel);
mmap=mean(maps,3);
imagesc(as,as,mmap(:,:));
if colbar
hcb=colorbar;
hcb.Title.String='p(correct)';
end
hold on
plot([as(1) as(end)],[as(1) as(end)],'w--','Linewidth',2);
xlabel('winner (\alpha^{+})')
ylabel('loser (\alpha^{-})')
set(gca,'Ydir','normal');
set(gca,'FontSize',fsize)
set(gca,'box','off')
hold on
xticks([0 0.05 0.1]);
yticks([0 0.05 0.1]);
if markersize>0
scatter(a1, a1,markersize, 'o','MarkerFaceColor', 'k','MarkerEdgeColor', 'k');
scatter(a1, 0 ,markersize, '^','MarkerFaceColor', 'r','MarkerEdgeColor', 'r');
end
end
function map = simperf(as,eta,tauI,gam,lam,tauP,trials,FBtru,grT,transonly)
ntrials=length(trials);
lasthalf=ceil(ntrials/2); % 1st trial in 2nd half of blocks
nas=length(as);
map=NaN(nas,nas);
for i=1:nas
for j=1:nas
CP=modfunQP(as(i),as(j),eta,tauI,gam,lam,tauP,trials,FBtru);
if transonly % accuracy on non-neighbours only
map(i,j)=1-mean(abs(grT(isnan(FBtru))-CP(isnan(FBtru))));
else
map(i,j)=1-mean(abs(grT(lasthalf:end)-CP(lasthalf:end)));
end
end
end
end
function alphadat=alphaout(k,D)
alphadat=ones(k,k);
if D % keep diagonal
for i=1:k
for j=1:k
if i<j
alphadat(i,j)=0;
end
end
end
else % do not keep diagonal
for i=1:k
for j=1:k
if i<=j
alphadat(i,j)=0;
end
end
end
end
end
\ No newline at end of file
function ticklabels_new=itemticks(vis)
labels={'\bf{A}';'\bf{B}';'\bf{C}';'\bf{D}';'\bf{E}';'\bf{F}';'\bf{G}';'\bf{H}'};
color=plasma(8);
if vis
tmp=plasma(40); % replace hardly visible yellow by slightly darker
color(end,:)=tmp(end-1,:);
end
color=flipud(color);
for i=1:length(labels)
actcolor=num2str(color(i,:));
ticklabels_new{i} = ['\color[rgb]{' actcolor '} ' labels{i}];
end
end
\ No newline at end of file
function cm_data=plasma(m)
cm = [[ 5.03832136e-02, 2.98028976e-02, 5.27974883e-01],
[ 6.35363639e-02, 2.84259729e-02, 5.33123681e-01],
[ 7.53531234e-02, 2.72063728e-02, 5.38007001e-01],
[ 8.62217979e-02, 2.61253206e-02, 5.42657691e-01],
[ 9.63786097e-02, 2.51650976e-02, 5.47103487e-01],
[ 1.05979704e-01, 2.43092436e-02, 5.51367851e-01],
[ 1.15123641e-01, 2.35562500e-02, 5.55467728e-01],
[ 1.23902903e-01, 2.28781011e-02, 5.59423480e-01],
[ 1.32380720e-01, 2.22583774e-02, 5.63250116e-01],
[ 1.40603076e-01, 2.16866674e-02, 5.66959485e-01],
[ 1.48606527e-01, 2.11535876e-02, 5.70561711e-01],
[ 1.56420649e-01, 2.06507174e-02, 5.74065446e-01],
[ 1.64069722e-01, 2.01705326e-02, 5.77478074e-01],
[ 1.71573925e-01, 1.97063415e-02, 5.80805890e-01],
[ 1.78950212e-01, 1.92522243e-02, 5.84054243e-01],
[ 1.86212958e-01, 1.88029767e-02, 5.87227661e-01],
[ 1.93374449e-01, 1.83540593e-02, 5.90329954e-01],
[ 2.00445260e-01, 1.79015512e-02, 5.93364304e-01],
[ 2.07434551e-01, 1.74421086e-02, 5.96333341e-01],
[ 2.14350298e-01, 1.69729276e-02, 5.99239207e-01],
[ 2.21196750e-01, 1.64970484e-02, 6.02083323e-01],
[ 2.27982971e-01, 1.60071509e-02, 6.04867403e-01],
[ 2.34714537e-01, 1.55015065e-02, 6.07592438e-01],
[ 2.41396253e-01, 1.49791041e-02, 6.10259089e-01],
[ 2.48032377e-01, 1.44393586e-02, 6.12867743e-01],
[ 2.54626690e-01, 1.38820918e-02, 6.15418537e-01],
[ 2.61182562e-01, 1.33075156e-02, 6.17911385e-01],
[ 2.67702993e-01, 1.27162163e-02, 6.20345997e-01],
[ 2.74190665e-01, 1.21091423e-02, 6.22721903e-01],
[ 2.80647969e-01, 1.14875915e-02, 6.25038468e-01],
[ 2.87076059e-01, 1.08554862e-02, 6.27294975e-01],
[ 2.93477695e-01, 1.02128849e-02, 6.29490490e-01],
[ 2.99855122e-01, 9.56079551e-03, 6.31623923e-01],
[ 3.06209825e-01, 8.90185346e-03, 6.33694102e-01],
[ 3.12543124e-01, 8.23900704e-03, 6.35699759e-01],
[ 3.18856183e-01, 7.57551051e-03, 6.37639537e-01],
[ 3.25150025e-01, 6.91491734e-03, 6.39512001e-01],
[ 3.31425547e-01, 6.26107379e-03, 6.41315649e-01],
[ 3.37683446e-01, 5.61830889e-03, 6.43048936e-01],
[ 3.43924591e-01, 4.99053080e-03, 6.44710195e-01],
[ 3.50149699e-01, 4.38202557e-03, 6.46297711e-01],
[ 3.56359209e-01, 3.79781761e-03, 6.47809772e-01],
[ 3.62553473e-01, 3.24319591e-03, 6.49244641e-01],
[ 3.68732762e-01, 2.72370721e-03, 6.50600561e-01],
[ 3.74897270e-01, 2.24514897e-03, 6.51875762e-01],
[ 3.81047116e-01, 1.81356205e-03, 6.53068467e-01],
[ 3.87182639e-01, 1.43446923e-03, 6.54176761e-01],
[ 3.93304010e-01, 1.11388259e-03, 6.55198755e-01],
[ 3.99410821e-01, 8.59420809e-04, 6.56132835e-01],
[ 4.05502914e-01, 6.78091517e-04, 6.56977276e-01],
[ 4.11580082e-01, 5.77101735e-04, 6.57730380e-01],
[ 4.17642063e-01, 5.63847476e-04, 6.58390492e-01],
[ 4.23688549e-01, 6.45902780e-04, 6.58956004e-01],
[ 4.29719186e-01, 8.31008207e-04, 6.59425363e-01],
[ 4.35733575e-01, 1.12705875e-03, 6.59797077e-01],
[ 4.41732123e-01, 1.53984779e-03, 6.60069009e-01],
[ 4.47713600e-01, 2.07954744e-03, 6.60240367e-01],
[ 4.53677394e-01, 2.75470302e-03, 6.60309966e-01],
[ 4.59622938e-01, 3.57374415e-03, 6.60276655e-01],
[ 4.65549631e-01, 4.54518084e-03, 6.60139383e-01],
[ 4.71456847e-01, 5.67758762e-03, 6.59897210e-01],
[ 4.77343929e-01, 6.97958743e-03, 6.59549311e-01],
[ 4.83210198e-01, 8.45983494e-03, 6.59094989e-01],
[ 4.89054951e-01, 1.01269996e-02, 6.58533677e-01],
[ 4.94877466e-01, 1.19897486e-02, 6.57864946e-01],
[ 5.00677687e-01, 1.40550640e-02, 6.57087561e-01],
[ 5.06454143e-01, 1.63333443e-02, 6.56202294e-01],
[ 5.12206035e-01, 1.88332232e-02, 6.55209222e-01],
[ 5.17932580e-01, 2.15631918e-02, 6.54108545e-01],
[ 5.23632990e-01, 2.45316468e-02, 6.52900629e-01],