code: Checking for rows correctly

This commit is contained in:
Manos Katsomallos 2021-09-29 12:58:19 +02:00
parent 4ea7fc2054
commit 5eb60ab286

View File

@ -370,7 +370,7 @@ def adaptive_cont(seq, lmdks, epsilon, inc_rt, dec_rt):
skipped = 0 skipped = 0
for i, p in enumerate(seq): for i, p in enumerate(seq):
# Check if current point is a landmark # Check if current point is a landmark
r = p[2] in lmdks r = any((lmdks[:]==p).all(1))
if r: if r:
lmdk_cur += 1 lmdk_cur += 1
if lmdk_lib.should_sample(samp_rt) or i == 0: if lmdk_lib.should_sample(samp_rt) or i == 0:
@ -447,7 +447,7 @@ def skip_cont(seq, lmdks, epsilon):
rls_data = [None]*len(seq) rls_data = [None]*len(seq)
for i, p in enumerate(seq): for i, p in enumerate(seq):
# Check if current point is a landmark # Check if current point is a landmark
r = p[2] in lmdks r = any((lmdks[:]==p).all(1))
# Add noise # Add noise
o = lmdk_lib.randomized_response(r, bgts[i]) o = lmdk_lib.randomized_response(r, bgts[i])
if r: if r:
@ -636,7 +636,7 @@ def uniform_cont(seq, lmdks, epsilon):
# Budgets # Budgets
bgts = uniform(seq, lmdks, epsilon) bgts = uniform(seq, lmdks, epsilon)
for i, p in enumerate(seq): for i, p in enumerate(seq):
r = p[2] in lmdks r = any((lmdks[:]==p).all(1))
# [original, perturbed] # [original, perturbed]
rls_data[i] = [r, lmdk_lib.randomized_response(r, bgts[i])] rls_data[i] = [r, lmdk_lib.randomized_response(r, bgts[i])]
return rls_data, bgts return rls_data, bgts