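# 无需得到最佳epoch,计算rank-1 (no need to pick a best epoch first: report the best Rank-1/5/10/20 of each split and their average)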
import os
def get_best_rank(log_file):
    """Return the best Rank-1/5/10/20 recorded in one training log.

    The log is filtered like ``grep -A 6 Results``. Inside each "Results"
    section, the lines containing 'Rank-1 ', 'Rank-5 ', 'Rank-10 ' and
    'Rank-20 ' are parsed as ``...: <value>%``. The best section is the one
    with the highest Rank-1; ties are broken by Rank-5, then Rank-10, then
    Rank-20.

    Args:
        log_file: path like ``<prefix>_<id>/log_train.txt``. The integer after
            the last '_' of the parent directory name is the split id.

    Returns:
        ``[split_id, rank1, rank5, rank10, rank20]``. If the log has no section
        reporting all four ranks, the four ranks are 0.

    Raises:
        ValueError: if the parent directory name does not end in ``_<int>``.

    Side effects: (over)writes ``get_old_result.txt`` in the current working
    directory and runs the external ``grep`` executable.
    """
    import subprocess  # local import: only this helper shells out

    # Same filter as `grep Results -A 6`, but with an argument list, so paths
    # containing spaces or shell metacharacters are passed through verbatim.
    with open('get_old_result.txt', 'w') as scratch:
        subprocess.run(['grep', '-A', '6', 'Results', log_file], stdout=scratch)

    labels = ('Rank-1 ', 'Rank-5 ', 'Rank-10 ', 'Rank-20 ')
    sections = []  # one {label: percentage} dict per "Results" section
    with open('get_old_result.txt') as f:
        for line in f:
            if 'Results' in line:
                sections.append({})
                continue
            if not sections:
                continue
            for label in labels:
                # The trailing space keeps 'Rank-1 ' from matching 'Rank-10'.
                if label in line:
                    # strip() also removes a '\r' left over from CRLF logs.
                    value = line.split(': ')[-1].strip().rstrip('%')
                    sections[-1][label] = float(value)
                    break

    # A section cut short (e.g. training stopped mid-evaluation) is skipped
    # instead of being paired by index with another section's ranks.
    complete = [tuple(s[label] for label in labels)
                for s in sections if len(s) == len(labels)]
    # Tuples compare lexicographically: Rank-1 first, then the tie-breakers.
    # (0, 0, 0, 0) is the floor reported when no section is complete.
    best = max([(0, 0, 0, 0)] + complete)

    split_dir = os.path.basename(os.path.dirname(log_file))
    split_id = int(split_dir.split('_')[-1])
    return [split_id] + list(best)
# Collect the training log of every experiment directory in the working
# directory. Directories without a log_train.txt (e.g. .git, __pycache__, or a
# run that has not started logging) are skipped, rather than crashing the sort
# below or adding an all-zero row to the average.
all_log_file = [
    name + '/log_train.txt'
    for name in os.listdir('.')
    if os.path.isfile(name + '/log_train.txt')
]
# Order by split id: the integer after the last '_' of the directory name,
# which is the same id get_best_rank reports in the first column.
all_log_file.sort(key=lambda path: int(path.split('/')[0].split('_')[-1]))
if not all_log_file:
    # Fail clearly instead of dividing by zero after a half-written res.txt.
    raise SystemExit('No <dir>/log_train.txt found in ' + os.getcwd())

totals = [0, 0, 0, 0]  # running sums of Rank-1, Rank-5, Rank-10, Rank-20
with open('res.txt', 'w') as f:
    f.write('Sp_ID Rank-1 Rank-5 Rank-10 Rank-20\n')
    for lf in all_log_file:
        one_res = get_best_rank(lf)  # [split_id, r1, r5, r10, r20]
        totals = [t + r for t, r in zip(totals, one_res[1:])]
        f.write(' '.join(str(v) for v in one_res) + '\n')
    avg = [round(t / len(all_log_file), 3) for t in totals]
    f.write('avg ' + ' '.join(str(v) for v in avg) + '\n')

# Align the columns into the final report, then drop the intermediate files.
os.system('column -t res.txt > result.txt')
for tmp_name in ('get_old_result.txt', 'res.txt'):
    try:
        os.remove(tmp_name)
    except FileNotFoundError:
        pass  # best-effort cleanup: the file was never created
# (blog page residue, not part of the script) 评论 / 发表评论 - "Comments" / "Post a comment"