How data loading works in deep_person_reid

I had never read the dataset_loader code in deep-person-reid carefully; it was not easy to follow. Recently I looked at the code of [1], whose data-loading part was written by the authors themselves and is fairly simple. After reading it and comparing it with the dataset_loader in deep-person-reid, the latter finally made sense to me. The original code has three sampling modes: evenly, random, and all.

evenly

evenly samples one frame every few frames across the whole tracklet. The stride is computed by first dropping the remainder and then sampling at equal intervals. For example, to take 8 frames out of 17, the stride is (17 - 17 % 8) / 8 = 2; to take 8 frames out of 27, the stride is (27 - 27 % 8) / 8 = 3.
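The stride arithmetic above can be isolated into a small standalone sketch (`evenly_indices` is my own name for illustration, not a function in deep-person-reid):

```python
import numpy as np

def evenly_indices(num, seq_len):
    """Sketch of the 'evenly' strategy: drop the remainder, then
    sample at a fixed stride of (num - num % seq_len) / seq_len."""
    num -= num % seq_len
    return np.arange(0, num, num // seq_len)

# 17 frames, 8 samples: stride (17 - 1) / 8 = 2
# 27 frames, 8 samples: stride (27 - 3) / 8 = 3
```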

random

random picks seq_len frames at random from the tracklet and then re-sorts them. The random mode in [2] works differently: it is the same as the method in M3D, which picks a random start position and then takes seq_len consecutive frames. Earlier experiments showed that taking seq_len consecutive frames works better.
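The two variants can be contrasted in a short sketch (the function names are mine; the wrap-around for short tracklets in the consecutive variant is my assumption, since the description only covers the num >= seq_len case):

```python
import numpy as np

rng = np.random.default_rng(0)

def random_sorted_indices(num, seq_len):
    # deep-person-reid style: sample without replacement (with replacement
    # if the tracklet is too short), then sort to keep temporal order
    replace = num < seq_len
    return np.sort(rng.choice(num, size=seq_len, replace=replace))

def random_consecutive_indices(num, seq_len):
    # [2]/M3D style: pick a random start position, then take seq_len
    # consecutive frames (wrapping around for short tracklets is my assumption)
    start = rng.integers(0, max(num - seq_len + 1, 1))
    return np.arange(start, start + seq_len) % num
```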

all

All frames in the tracklet are selected.

dense

dense is the method from [2], the same as M3D's: both pad the last clip by repeating its own elements, e.g. [0, 1, 2, 3] with seq_len=6 becomes [0, 1, 2, 3, 0, 1].

import numpy as np
import torch
from torch.utils.data import Dataset

# read_image is a PIL-based image loader defined elsewhere in deep-person-reid

class VideoDataset(Dataset):
    """Video Person ReID Dataset.
    Note batch data has shape (batch, seq_len, channel, height, width).
    """
    _sample_methods = ['evenly', 'random', 'all']

    def __init__(self, dataset, seq_len=15, sample_method='evenly', transform=None):
        self.dataset = dataset
        self.seq_len = seq_len
        self.sample_method = sample_method
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img_paths, pid, camid = self.dataset[index]
        num = len(img_paths)

        if self.sample_method == 'random':
            """
            Randomly sample seq_len items from num items,
            if num is smaller than seq_len, then replicate items
            """
            indices = np.arange(num)
            replace = False if num >= self.seq_len else True
            indices = np.random.choice(indices, size=self.seq_len, replace=replace)
            # sort indices to keep temporal order (comment it to be order-agnostic)
            indices = np.sort(indices)
        
        elif self.sample_method == 'evenly':
            """
            Evenly sample seq_len items from num items.
            """
            if num >= self.seq_len:
                num -= num % self.seq_len
                indices = np.arange(0, num, num // self.seq_len)
            else:
                # if num is smaller than seq_len, simply replicate the last image
                # until the seq_len requirement is satisfied
                indices = np.arange(0, num)
                num_pads = self.seq_len - num
                indices = np.concatenate([indices, np.ones(num_pads).astype(np.int32)*(num-1)])
            assert len(indices) == self.seq_len
    
        elif self.sample_method == 'all':
            """
            Sample all items, seq_len is useless now and batch_size needs to be set to 1.
            """
            indices = np.arange(num)
    
        else:
            raise ValueError('Unknown sample method: {}. Expected one of {}'.format(self.sample_method, self._sample_methods))

        imgs = []
        for index in indices:
            img_path = img_paths[int(index)]
            img = read_image(img_path)
            if self.transform is not None:
                img = self.transform(img)
            img = img.unsqueeze(0)
            imgs.append(img)
        imgs = torch.cat(imgs, dim=0)

        return imgs, pid, camid    

dense code

        elif self.sample == 'dense':
            """
            Sample all frames in a video into a list of clips, each clip
            containing seq_len frames; batch_size needs to be set to 1.
            This sampling strategy is used in the test phase.
            """
            cur_index = 0
            frame_indices = list(range(num))  # list(...) needed in Python 3
            indices_list = []
            while num - cur_index > self.seq_len:
                indices_list.append(frame_indices[cur_index:cur_index + self.seq_len])
                cur_index += self.seq_len
            # pad the last (shorter) clip by cycling through its own elements
            last_seq = frame_indices[cur_index:]
            i = 0
            while len(last_seq) < self.seq_len:
                last_seq.append(last_seq[i])
                i += 1
            indices_list.append(last_seq)
            imgs_list = []
            for indices in indices_list:
                imgs = []
                for index in indices:
                    img_path = img_paths[int(index)]
                    img = read_image(img_path)
                    if self.transform is not None:
                        img = self.transform(img)
                    imgs.append(img.unsqueeze(0))
                imgs = torch.cat(imgs, dim=0)
                imgs_list.append(imgs)
            imgs_array = torch.stack(imgs_list)  # shape: (num_clips, seq_len, C, H, W)

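The clip-splitting and padding logic above can be reduced to a pure-Python sketch over indices only (`dense_clip_indices` is my own name for illustration):

```python
def dense_clip_indices(num, seq_len):
    """Split frame indices 0..num-1 into clips of seq_len frames;
    the last, shorter clip is padded by cycling through its own
    elements, e.g. [0, 1, 2, 3] with seq_len=6 -> [0, 1, 2, 3, 0, 1]."""
    frame_indices = list(range(num))
    clips, cur = [], 0
    while num - cur > seq_len:
        clips.append(frame_indices[cur:cur + seq_len])
        cur += seq_len
    last = frame_indices[cur:]
    i = 0
    while len(last) < seq_len:
        last.append(last[i])
        i += 1
    clips.append(last)
    return clips
```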
[1] Multi-scale 3D Convolution Network for Video Based Person Re-Identification
[2] Revisiting Temporal Modeling for Video-based Person ReID
