diff --git a/code/expt/t-drive.py b/code/expt/t-drive.py index 4843cb4..6c485a1 100644 --- a/code/expt/t-drive.py +++ b/code/expt/t-drive.py @@ -75,14 +75,16 @@ def main(args): mae_u = np.zeros(len(data_info[d]['lmdks'])) mae_s = np.zeros(len(data_info[d]['lmdks'])) mae_a = np.zeros(len(data_info[d]['lmdks'])) + mae_d = np.zeros(len(data_info[d]['lmdks'])) mae_evt = 0 mae_usr = 0 for i, lmdk in enumerate(data_info[d]['lmdks']): # Find landmarks args.dist = data_info[d]['lmdks'][lmdk]['dist'] args.per = data_info[d]['lmdks'][lmdk]['per'] - lmdks = lmdk_lib.find_lmdks(seq, args)[:args.time] + lmdks = lmdk_lib.find_lmdks(seq, args) for bgt in bgt_conf: + s_d = 0 for _ in range(args.iter): # Skip rls_data_s, _ = lmdk_bgt.skip(seq, lmdks, bgt['epsilon']) @@ -96,6 +98,11 @@ def main(args): rls_data_a, _, _ = lmdk_bgt.adaptive(seq, lmdks, bgt['epsilon'], .5, .5) mae_a[i] += lmdk_bgt.mae(seq, rls_data_a)/args.iter + # # Dynamic + # rls_data_a, _, s_d_c = lmdk_bgt.dynamic(seq, lmdks, bgt['epsilon'], .5, .5) + # mae_d[i] += lmdk_bgt.mae(seq, rls_data_a)/args.iter + # s_d += s_d_c/args.iter + # Event if lmdk == 0: rls_data_evt, _ = lmdk_bgt.uniform_r(seq, lmdks, bgt['epsilon']) @@ -143,6 +150,14 @@ def main(args): label='Adaptive', linewidth=lmdk_lib.line_width ) + # x_offset += bar_width + # plt.bar( + # x_i + x_offset, + # mae_d, + # bar_width, + # label='Dynamic', + # linewidth=lmdk_lib.line_width + # ) path = str('../../rslt/bgt_cmp/' + d) # Plot legend diff --git a/code/lib/lmdk_lib.py b/code/lib/lmdk_lib.py index 3820b6e..e85ff95 100644 --- a/code/lib/lmdk_lib.py +++ b/code/lib/lmdk_lib.py @@ -856,7 +856,6 @@ def find_lmdks(usrs_data, args): ''' usrs_lmdks = np.empty((0,4), np.float32) traj_cur = 0 - lmdk_id = 0 usrs = np.unique(usrs_data[:,0]) for usr_i, usr in enumerate(usrs): # Initialize user's landmarks list @@ -888,11 +887,6 @@ def find_lmdks(usrs_data, args): per = abs(datetime.fromtimestamp(int(traj[i][3])) - datetime.fromtimestamp(int(traj[j][3]))).total_seconds()/60 # Check if enough time passed if per > args.per: - # usrs_id starts from 1 - lmdk_id += 1 - # Assign id to current landmark - for l in lmdk_cur: - l[0] = lmdk_id # Append current landmark lmdks += lmdk_cur # Continue checking from the current point @@ -910,7 +904,7 @@ def find_lmdks(usrs_data, args): def find_lmdks_seq(seq, lmdks): lmdks_seq = [] for i, p in enumerate(seq): - if any(np.equal(lmdks, p).all(1)): + if is_landmark(p, lmdks): lmdks_seq.append(i + 1) return np.array(lmdks_seq, dtype = int) @@ -1073,6 +1067,6 @@ def is_landmark(p, lmdks): Returns: True/False ''' - if len(lmdks) and any(np.equal(lmdks, p).all(1)): + if len(lmdks) > 0 and any(np.equal(lmdks, p).all(1)): return True return False