How deep_person_reid loads video data
I had never read the dataset_loader code in deep_person_reid carefully, and it was not easy to follow. Recently I went through the code of [1], whose data-loading part was written by the authors themselves and is fairly simple. After reading it and comparing it against the dataset_loader in deep_person_reid, I found I could finally understand the latter. The original loader has three sampling modes: evenly, random, and all.
evenly
evenly picks one frame every few frames across the whole tracklet. The stride is computed by first discarding the remainder, then sampling at equal intervals. For example, to take 8 frames out of 17, the stride is ( 17 - 17 % 8 ) / 8 = 2; to take 8 frames out of 27, the stride is ( 27 - 27 % 8 ) / 8 = 3.
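The stride rule above can be sketched as a small helper (a sketch only; `evenly_indices` is a made-up name, not the repo's API):

```python
import numpy as np

def evenly_indices(num, seq_len):
    """Sketch of the 'evenly' rule: drop the remainder, then sample at a fixed stride."""
    num -= num % seq_len           # discard the remainder part first
    step = num // seq_len          # stride, e.g. (17 - 17 % 8) // 8 = 2
    return np.arange(0, num, step)

print(evenly_indices(17, 8))  # [ 0  2  4  6  8 10 12 14]
print(evenly_indices(27, 8))  # [ 0  3  6  9 12 15 18 21]
```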
random
random picks seq_len frames from the tracklet at random and then re-sorts them. The random mode in [2] works differently: it matches the method in M3D, which picks a random start position and then takes seq_len consecutive frames. Earlier experiments showed that taking seq_len consecutive frames performs better.
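The two random strategies can be contrasted in a few lines (a sketch with made-up helper names; the contiguous variant assumes num >= seq_len):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_sorted(num, seq_len):
    # deep_person_reid style: sample frames at random, then restore temporal order
    replace = num < seq_len
    idx = rng.choice(np.arange(num), size=seq_len, replace=replace)
    return np.sort(idx)

def random_contiguous(num, seq_len):
    # [2] / M3D style: pick a random start, then take seq_len consecutive frames
    start = rng.integers(0, num - seq_len + 1)
    return np.arange(start, start + seq_len)

print(random_sorted(17, 8))      # 8 distinct indices in ascending order
print(random_contiguous(17, 8))  # 8 consecutive indices
```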
all
All frames in the tracklet are selected.
dense
dense is the method from [2], the same as in M3D: all frames are split into consecutive clips, and the last clip is padded by repeating its own elements, e.g. [0, 1, 2, 3, 0, 1].
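The clip-splitting rule can be sketched as follows (`dense_clips` is a hypothetical helper, not the repo's API):

```python
from itertools import cycle, islice

def dense_clips(num, seq_len):
    """Split num frames into consecutive clips of seq_len frames;
    pad the last clip by cycling through its own frames."""
    clips = [list(range(s, min(s + seq_len, num))) for s in range(0, num, seq_len)]
    clips[-1] = list(islice(cycle(clips[-1]), seq_len))
    return clips

print(dense_clips(4, 6))   # [[0, 1, 2, 3, 0, 1]]
print(dense_clips(10, 6))  # [[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 6, 7]]
```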
import numpy as np
import torch
from torch.utils.data import Dataset


class VideoDataset(Dataset):
    """Video Person ReID Dataset.

    Note: batch data has shape (batch, seq_len, channel, height, width).
    """
    _sample_methods = ['evenly', 'random', 'all']

    def __init__(self, dataset, seq_len=15, sample_method='evenly', transform=None):
        self.dataset = dataset
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_paths, pid, camid = self.dataset[index]
        num = len(img_paths)

        if self.sample_method == 'random':
            # Randomly sample seq_len items from num items;
            # if num is smaller than seq_len, sample with replacement.
            indices = np.arange(num)
            replace = num < self.seq_len
            indices = np.random.choice(indices, size=self.seq_len, replace=replace)
            # sort indices to keep temporal order (comment it out to be order-agnostic)
            indices = np.sort(indices)
        elif self.sample_method == 'evenly':
            # Evenly sample seq_len items from num items.
            if num >= self.seq_len:
                num -= num % self.seq_len
                indices = np.arange(0, num, num // self.seq_len)
            else:
                # if num is smaller than seq_len, simply replicate the last image
                # until the seq_len requirement is satisfied
                indices = np.arange(0, num)
                num_pads = self.seq_len - num
                indices = np.concatenate([indices, np.ones(num_pads, dtype=np.int32) * (num - 1)])
            assert len(indices) == self.seq_len
        elif self.sample_method == 'all':
            # Sample all items; seq_len is ignored and batch_size must be set to 1.
            indices = np.arange(num)
        else:
            raise ValueError('Unknown sample method: {}. Expected one of {}'.format(
                self.sample_method, self._sample_methods))

        imgs = []
        for index in indices:
            img_path = img_paths[int(index)]
            img = read_image(img_path)  # read_image is the repo's image-loading helper
            if self.transform is not None:
                img = self.transform(img)
            img = img.unsqueeze(0)
            imgs.append(img)
        imgs = torch.cat(imgs, dim=0)
        return imgs, pid, camid
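When the tracklet is shorter than seq_len, the 'evenly' branch pads by repeating the last frame index; isolated as a sketch (hypothetical helper name):

```python
import numpy as np

def pad_with_last(num, seq_len):
    # mirrors the num < seq_len branch: replicate the last index until seq_len
    indices = np.arange(0, num)
    num_pads = seq_len - num
    return np.concatenate([indices, np.ones(num_pads, dtype=np.int64) * (num - 1)])

print(pad_with_last(5, 8))  # [0 1 2 3 4 4 4 4]
```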
dense code
        elif self.sample == 'dense':
            # Sample all frames of a video as a list of clips, each containing
            # seq_len frames; batch_size must be set to 1.
            # This sampling strategy is used in the test phase.
            cur_index = 0
            frame_indices = list(range(num))  # list(), so the last clip can be appended to
            indices_list = []
            while num - cur_index > self.seq_len:
                indices_list.append(frame_indices[cur_index:cur_index + self.seq_len])
                cur_index += self.seq_len
            # pad the last clip by cycling through its own frames, e.g. [0, 1, 2, 3, 0, 1]
            last_seq = frame_indices[cur_index:]
            for index in last_seq:
                if len(last_seq) >= self.seq_len:
                    break
                last_seq.append(index)
            indices_list.append(last_seq)

            imgs_list = []
            for indices in indices_list:
                imgs = []
                for index in indices:
                    img_path = img_paths[int(index)]
                    img = read_image(img_path)
                    if self.transform is not None:
                        img = self.transform(img)
                    #img = img[1:,:,:] # delete the original z channel, which is all zeros before transform.
                    img = img.unsqueeze(0)
                    imgs.append(img)
                imgs = torch.cat(imgs, dim=0)
                #imgs = imgs.permute(1,0,2,3)
                imgs_list.append(imgs)
            # shape: (num_clips, seq_len, channel, height, width)
            imgs_array = torch.stack(imgs_list)
            return imgs_array, pid, camid
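At test time the resulting (num_clips, seq_len, C, H, W) tensor is typically consumed clip by clip. A hedged sketch of averaging per-clip features (the extractor here is a toy lambda standing in for the actual model, and `aggregate_clip_features` is a made-up name):

```python
import torch

def aggregate_clip_features(imgs_array, extractor):
    """Run a feature extractor on each clip and average the resulting vectors."""
    feats = [extractor(clip) for clip in imgs_array]  # clip: (seq_len, C, H, W)
    return torch.stack(feats).mean(dim=0)

# toy check: a fake extractor that averages over frames and pixels
imgs = torch.randn(3, 4, 3, 8, 4)  # 3 clips, seq_len 4, 3-channel 8x4 images
feat = aggregate_clip_features(imgs, lambda clip: clip.mean(dim=(0, 2, 3)))
print(feat.shape)  # torch.Size([3])
```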
[1] Multi-scale 3D Convolution Network for Video Based Person Re-Identification
[2] Revisiting Temporal Modeling for Video-based Person ReID