JDK源码阅读之ThreadLocal

写在前面

This class provides thread-local variables. These variables differ from their normal counterparts in that each thread that accesses one (via its {@code get} or {@code set} method) has its own, independently initialized copy of the variable. {@code ThreadLocal} instances are typically private static fields in classes that wish to associate state with a thread (e.g.,a user ID or Transaction ID).

源码分析自 JDK 1.8.171

ThreadLocal 是一个用于暂存线程本地变量的工具数据结构类。

使用示例代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
public class ThreadId {
// Atomic integer containing the next thread ID to be assigned
private static final AtomicInteger nextId = new AtomicInteger(0);

// Thread local variable containing each thread's ID
private static final ThreadLocal<Integer> threadId =
new ThreadLocal<Integer>() {
@Override
protected Integer initialValue() {
// 如果get时,没有set,则设置一个初始值并返回
return nextId.getAndIncrement();
}
};

// Returns the current thread's unique ID, assigning it if necessary
public static int get() {
return threadId.get();
}
}

概述

ThreadLocal 能够存储当前线程可见的本地变量。在ThreadLocal中有一个 ThreadLocalMap 静态类,这个类才是具体存储变量的数据结构。在Thread中有一个声明为 ThreadLocal.ThreadLocalMap threadLocals = null; 的变量,这个变量就是存储的当前线程可见的本地变量,这个threadlocals由ThreadLocal类维护。

内部变量

  • private final int threadLocalHashCode = nextHashCode(); // ThreadLocal在线程Thread的hash值
  • private static AtomicInteger nextHashCode = new AtomicInteger(); // 下一个ThreadLocal的hash值,一个线程可能会有多个ThreadLocal
  • private static final int HASH_INCREMENT = 0x61c88647; // 两个ThreadLocal之间hash差值

构造方法

  • public ThreadLocal() // 默认构造方法
1
2
3
4
5
6
/**
* Creates a thread local variable.
* @see #withInitial(java.util.function.Supplier)
*/
public ThreadLocal() {
}
  • public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) // 带初始值的构造方法
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}

static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

private final Supplier<? extends T> supplier;

SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}

@Override
protected T initialValue() {
return supplier.get();
}
}

initialValue方法是在调用get()方法的时候可能会调用的
**

ThreadLocalMap

ThreadLocalMap是一个hash map。这个map的 Entry是在ThreadLocalMap内部的一个静态类,Entry 继承了 WeakReference<ThreadLocal<?>>

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
static class ThreadLocalMap {
/**
* 注意这里的entry.get() 为null的话,则ThreadLocal为null,对应的这个entry也可以清理了,则这个table的这个entry也可以被清理了
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;

Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}

private static final int INITIAL_CAPACITY = 16; // 默认大小,与HashMap一样须为2的幂次方

private Entry[] table; // entry容器

private int size = 0; // table中entry数量

private int threshold; // Default to 0, size>=threshold就应该扩容了

private void setThreshold(int len) {
threshold = len * 2 / 3;
}

// (i + 1) % len
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}

// (i - 1) % len
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
table = new Entry[INITIAL_CAPACITY];
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
// INITIAL_CAPACITY * load_factor
setThreshold(INITIAL_CAPACITY);
}

private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 存在多个ThreadLocal的时候
return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;

while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
if (k == null)
// 遇到key为空就清理
expungeStaleEntry(i);
else
// 继续向下找
i = nextIndex(i, len);
e = tab[i];
}
return null;
}

private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);

for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();

// 就是当前thread local所占位置,直接替换value
if (k == key) {
e.value = value;
return;
}

// key为空,那么直接替换掉entry
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}

// table上的i位置的entry为空
tab[i] = new Entry(key, value);
int sz = ++size;
// 如果i往后并没有entry被清理掉,并且当前的size已经大于threshold了,则rehash一下下
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;

// 往前找到最远的一个key被回收的entry
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;

// 往后找,直到entry为null
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();

// key相等,就是当前thread local
if (k == key) {
e.value = value;

// 将i和staleSlot互换
tab[i] = tab[staleSlot];
tab[staleSlot] = e;

// Start expunge at preceding stale entry if it exists
// 就是清理
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}

if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}

tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);

if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

// 清理staleSlot位置上的entry及key value,并且从stale Slot往后rehash
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;

// expunge entry at staleSlot
// 清理table的staleSlot位置
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;

// Rehash until we encounter null
Entry e;
int i;
// rehash,往后遍历,直到遇到table 的 entry 为null
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// 如果entry的key为null,手动清理一下子,因为key为weakReference,所以可能存在key被清理的情况
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// 如果entry的key不等于null,找到当前时刻i对应的的entry
int h = k.threadLocalHashCode & (len - 1);
// 如果不相等,一个线程多个thread local时
if (h != i) {
// 那么清理i位置的entry,但是value、key仍在
tab[i] = null;

// 将i往后挪动到h,到一个空位上
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}

// 如果i位置的entry的key已被回收,则清理
// 如果i往后存在key被回收的entry,搭着清一下,助人为乐
// 注意O(log2(len))
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
// 如果i位置entry不为空,但是key已经被回收
if (e != null && e.get() == null) {
// 则清理,并重置遍历状态
n = len;
removed = true;
// 清理i位置的entry,并且从i开始往后rehash,直到遇到null
i = expungeStaleEntry(i); // 返回值为i后的table的一个空位
}
} while ( (n >>>= 1) != 0); // 最后O(log2(len))
return removed;
}

private void rehash() {
// 先清理key已经被回收的entry
expungeStaleEntries();

// Use lower threshold for doubling to avoid hysteresis
// len*2/3*3/4 result is 1/2
if (size >= threshold - threshold / 4)
resize();
}

private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// key = null
e.clear();
// 清理i位置上entry的key and value,并且往后rehash
expungeStaleEntry(i);
return;
}
}
}

// 清理废物entry(key已经被回收的)
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
// 如果key已经被回收,则清理当前entry
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}

// ... 省略其他代码

}

#

源码分析

ThreadLocal的一般用法就是set和get方法,就从这两个方法作为入口分析一下逻辑流程。

公共逻辑

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
ThreadLocalMap getMap(Thread t) {
// 由类Thread维护一个ThreadLocalMap
return t.threadLocals;
}

void createMap(Thread t, T firstValue) {
// 初始化ThreadLocalMap
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
// 在当前线程Thread中维护的TreahdLocalMap中,remove以当前threadLocal作为key的键值对
m.remove(this);
}

private T setInitialValue() {
// 初始值
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}

set

1
2
3
4
5
6
7
8
9
10
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
// 将<K,V>为<threadLocal,value>对put到threadLocalMap中
map.set(this, value);
else
// 第一次put值时才会初始化threadLocalMap
createMap(t, value);
}

get

1
2
3
4
5
6
7
8
9
10
11
12
13
14
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 为空的时候允许设置个初始值
return setInitialValue();
}

总结

  • ThreadLocal数据存储由一个ThreadLocalMap的数据结构完成,每个Thread类维护一个ThreadLocalMap
  • ThreadLocalMap是一个Key为ThreadLocal的map,内部是由一个Entry[]数组存储元素,Entry继承WeakRefenrence<ThreadLocal<?>>
  • 当一个Thread存在多个ThreadLocal时,ThreadLocalMap可能存储扩容的情况
  • Thread类内部维护的ThreadLocalMap的初始化是lazy的,就是说只有当get或者set的时候,才会去初始化,并且允许带初始值(get的时候)
  • ThreadLocalMap的Entry继承WeakReference<ThreadLocal<?>>,由它的类声明,可以理解为key是一个弱引用,但是value确实一个强引用,所以在一些情况下,key引用被回收了,value引用是未被回收的结,结合源码,在set、get和remove的时候都会涉及到包含这种情况的清理,试想在一段时间内,不涉及到set、get、remove操作,那么大概率还是存在内存泄漏的,所以一般建议是使用后手动remove以避免内存泄漏
-------------The End-------------