Ejemplo n.º 1
0
    def __getitem__(self, index):
        """Load a stereo pair and, unless in submission mode, its disparity."""
        left_path = self.left[index]
        right_path = self.right[index]

        left_img = self.loader(left_path)
        right_img = self.loader(right_path)

        if not self.submission:
            # Ground-truth disparity is only available outside submission mode.
            disp_path = self.disp_L[index]
            dataL = self.dploader(disp_path)
            dataL = np.ascontiguousarray(dataL, dtype=np.float32)

        if self.training:
            # Random 256x512 crop, applied identically to both views and GT.
            w, h = left_img.size
            th, tw = 256, 512
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            box = (x1, y1, x1 + tw, y1 + th)
            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            to_tensor = preprocess.get_transform(augment=False)
            return to_tensor(left_img), to_tensor(right_img), dataL
        else:
            # Fixed bottom-right 1024x416 crop for evaluation/submission.
            w, h = left_img.size
            box = (w - 1024, h - 416, w, h)
            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            if not self.submission:
                # Trim 12 rows top and bottom of the GT to match the net output.
                dataL = dataL[12:-12, :]

            to_tensor = preprocess.get_transform(augment=False)
            left_img = torch.clamp(to_tensor(left_img), -1, 1)
            right_img = torch.clamp(to_tensor(right_img), -1, 1)

            if self.submission:
                return left_img, right_img
            return left_img, right_img, dataL
Ejemplo n.º 2
0
    def __getitem__(self, index):
        """Return (left, right, disparity, left_path) for sample *index*.

        Sanity-checks that the left/right/disparity filenames refer to the
        same frame, crops a fixed 592-row band starting at row 244, then in
        training mode takes an additional random 256x512 crop.
        """
        left = self.left[index]
        right = self.right[index]
        disp_L = self.disp_L[index]

        # Filenames must describe the same frame; the sliced suffix lengths
        # differ between the left/right/disparity naming schemes.
        assert left.split('/')[-2] == right.split(
            '/')[-2], "L,R is not matched {}".format(left)
        assert left.split('/')[-1][:-9] == right.split(
            '/')[-1][:-10], "L,R is not matched {},{},{},{}".format(
                left.split('/')[-1],
                right.split('/')[-1], left, index)
        assert left.split('/')[-1][:-9] == disp_L.split(
            '/')[-1][:-11], "L,disp is not matched {},{},{},{}".format(
                left.split('/')[-1],
                right.split('/')[-1], left, index)

        left_img = self.loader(left)
        right_img = self.loader(right)
        dataL = self.dploader(disp_L)
        try:
            dataL = np.ascontiguousarray(dataL, dtype=np.float32)
        except TypeError:
            # Log the offending sample, then re-raise: the original swallowed
            # the error and continued with a non-array disparity, which only
            # failed later with a far less useful message at the slicing below.
            print(dataL)
            print(disp_L)
            raise

        w, h = left_img.size
        left_img = left_img.crop((0, 244, w, 244 + 592))
        right_img = right_img.crop((0, 244, w, 244 + 592))
        # Rescale ground truth to the output coding used by the network.
        dataL = dataL[
            244:-244, :] / 100.0 * 1920 / 1248  # Ground Truth output coding

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
            right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))

            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

            return left_img, right_img, dataL
        else:
            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

            # Evaluation additionally returns the left image path for logging.
            return left_img, right_img, dataL, left
Ejemplo n.º 3
0
    def __getitem__(self, index):
        """Load a frame's RGB pair plus a .npy disparity, crop, and sparsify."""
        frame = self.frame_ids[index]
        left_path = os.path.join(self.left_dir, '{0}.png'.format(frame))
        right_path = os.path.join(self.right_dir, '{0}.png'.format(frame))
        disp_path = os.path.join(self.disp_dir, '{0}.npy'.format(frame))

        left_img = Image.open(left_path).convert('RGB')
        right_img = Image.open(right_path).convert('RGB')
        dataL = np.load(disp_path)

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            # Only the bottom half of the frame has depth measurements.
            y1 = random.randint(80, h - th)
            box = (x1, y1, x1 + tw, y1 + th)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)

            # Crop the disparity through PIL so it matches the image crop.
            dataL = np.asarray(Image.fromarray(dataL).crop(box))

            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            return left_img, right_img, dataL, get_sparse_disp(dataL, erase_ratio=0.8)
        else:
            # Fixed bottom-right 1248x352 crop; PIL zero-pads when the frame
            # is smaller than the target.
            target_w, target_h = 1248, 352
            w, h = left_img.size
            box = (w - target_w, h - target_h, w, h)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            dataL = np.asarray(Image.fromarray(dataL).crop(box))

            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            return left_img, right_img, dataL, get_sparse_disp(dataL, erase_ratio=0.8)
Ejemplo n.º 4
0
    def __getitem__(self, index):
        """Return (left, right, disparity) for sample *index*.

        Training: random 256x512 crop of both views and the GT.
        Evaluation: fixed bottom-right 1232x368 image crop.
        """
        left = self.left[index]
        right = self.right[index]
        disp_L = self.disp_L[index]

        left_img = self.loader(left)
        right_img = self.loader(right)
        dataL = self.dploader(disp_L)

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            #print("index %d: random x1 is %d, and random y1 is %d with weight: %d; height: %d"%(index,x1,y1,w,h))
            left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
            right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))

            #dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
            dataL = np.ascontiguousarray(dataL, dtype=np.float32)
            #print(dataL.shape)
            dataL = dataL[y1:y1 + th, x1:x1 + tw]
            #dataL = dataL[x1:x1 + tw, y1:y1 + th]
            #print(dataL.shape)

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

            return left_img, right_img, dataL
        else:
            #print "else is also ran"
            pre_w = 1226
            pre_h = 370
            w, h = left_img.size

            left_img = left_img.crop((w - 1232, h - 368, w, h))
            right_img = right_img.crop((w - 1232, h - 368, w, h))
            w1, h1 = left_img.size

            #dataL = dataL.crop((w-1232, h-368, w, h))
            # NOTE(review): this slice starts at h-370/w-1226 and numpy clamps
            # the out-of-range stop, so the GT comes out ~370x1226 while the
            # images are cropped to 368x1232 — the shapes do not match. It may
            # be intentional for this dataset's evaluation; confirm before use.
            dataL = dataL[h - pre_h:h - pre_h + h, w - pre_w:w - pre_w + w]
            #dataL = np.ascontiguousarray(dataL,dtype=np.float32)/256
            dataL = np.ascontiguousarray(dataL, dtype=np.float32)

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

            return left_img, right_img, dataL
Ejemplo n.º 5
0
    def __getitem__(self, index):
        """Return (left, right, disparity/256), all cropped to 368x1232.

        Both the original training and evaluation branches performed the
        identical fixed bottom-right 368x1232 crop, so the duplication,
        the unused ``processed`` transform and the unused ``w1, h1`` reads
        were removed. The stored disparity is uint16 scaled by 256, hence
        the division.
        """
        left = self.left[index]
        right = self.right[index]
        disp_L = self.disp_L[index]

        left_img = self.loader(left)
        right_img = self.loader(right)
        dataL = self.dploader(disp_L)

        w, h = left_img.size
        th, tw = 368, 1232

        left_img = left_img.crop((w - tw, h - th, w, h))
        right_img = right_img.crop((w - tw, h - th, w, h))
        dataL = dataL.crop((w - tw, h - th, w, h))
        dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256

        left_img = transforms.ToTensor()(left_img)
        right_img = transforms.ToTensor()(right_img)

        return left_img, right_img, dataL
Ejemplo n.º 6
0
    def __getitem__(self, index):
        """Random 256x512 crop of an RGB image plus its GT/sparse targets."""
        rgb_path = self.left[index]
        normal_path = self.normal[index]
        gt_path = self.gts[index]
        left_img = self.loader(rgb_path)
        w, h = left_img.size
        input1, mask1 = self.inloader(gt_path)  # 3channel/mask
        sparse, mask = self.sloader(normal_path)

        th, tw = 256, 512
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        rows = slice(y1, y1 + th)
        cols = slice(x1, x1 + tw)
        left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
        data_in1 = input1[rows, cols, :]
        sparse_n = sparse[rows, cols, :]
        mask = mask[rows, cols, :]
        mask1 = mask1[rows, cols, :]

        # numpy(H, W, C) -> FloatTensor(C, H, W); only the image and the
        # sparse map go through the tensor transform — masks stay as numpy.
        to_tensor = preprocess.get_transform(augment=False)
        return to_tensor(left_img), to_tensor(sparse_n), mask, mask1, data_in1
Ejemplo n.º 7
0
    def __getitem__(self, index):
        """Random 256x512 crop of image/GT-input/sparse plus an HxWx3
        intrinsics map for the sample's recording sequence."""
        left = self.left[index]
        input_path = self.input[index]  # renamed: don't shadow the builtin input()
        sparse_path = self.sparse[index]
        left_img = self.loader(left)

        # First 10 chars of the sequence directory name key the intrinsics.
        index_str = self.left[index].split('/')[-4][0:10]
        params_t = INSTICS[index_str]
        params = np.ones((256, 512, 3), dtype=np.float32)
        params[:, :, 0] = params[:, :, 0] * params_t[0]
        params[:, :, 1] = params[:, :, 1] * params_t[1]
        params[:, :, 2] = params[:, :, 2] * params_t[2]

        # left_img is a numpy array here (H, W, C) — note .shape, not .size.
        h, w, c = left_img.shape
        input1 = self.inloader(input_path)
        sparse, mask = self.sloader(sparse_path)

        th, tw = 256, 512
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        mask = np.reshape(mask, [sparse.shape[0], sparse.shape[1], 1]).astype(
            np.float32)
        params = np.reshape(params, [256, 512, 3]).astype(np.float32)

        left_img = left_img[y1:y1 + th, x1:x1 + tw, :]
        data_in1 = input1[y1:y1 + th, x1:x1 + tw, :]
        sparse = sparse[y1:y1 + th, x1:x1 + tw, :]
        mask = mask[y1:y1 + th, x1:x1 + tw, :]
        processed = preprocess.get_transform(augment=False)

        left_img = processed(left_img)
        sparse = processed(sparse)
        mask = processed(mask)

        return left_img, data_in1, sparse, mask, params
Ejemplo n.º 8
0
    def get_loader(self, force_update=False):
        """(Re)build and return the DataLoader for the current regime step.

        The loader is rebuilt when *force_update* is set or when the training
        regime reports changed settings for the current epoch/step.
        """
        if force_update or self.regime.update(self.epoch, self.steps):
            setting = self.get_setting()
            self._transform = get_transform(**setting['transform'])
            # Only install the transform if the data config didn't specify one.
            setting['data'].setdefault('transform', self._transform)
            self._data = get_dataset(**setting['data'])
            if setting['other'].get('distributed', False):
                setting['loader']['sampler'] = DistributedSampler(self._data)
                setting['loader']['shuffle'] = None
                # pin-memory currently broken for distributed
                setting['loader']['pin_memory'] = False
            if setting['other'].get('duplicates', 0) > 1:
                # Batches are repeated by DuplicateBatchSampler, so
                # DataLoader-level shuffling must be disabled.
                setting['loader']['shuffle'] = None
                sampler = setting['loader'].get(
                    'sampler', RandomSampler(self._data))
                setting['loader']['sampler'] = DuplicateBatchSampler(sampler, setting['loader']['batch_size'],
                                                                     duplicates=setting['other']['duplicates'],
                                                                     drop_last=setting['loader'].get('drop_last', False))

            self._sampler = setting['loader'].get('sampler', None)
            self._loader = torch.utils.data.DataLoader(
                self._data, **setting['loader'])
            if setting['other'].get('duplicates', 0) > 1:
                # NOTE(review): batch_sampler is assigned after construction,
                # bypassing DataLoader's usual batch_size/batch_sampler
                # validation — confirm this behaves as intended.
                self._loader.batch_sampler = self._sampler
        return self._loader
Ejemplo n.º 9
0
    def __getitem__(self, index):
        """Fixed 1280x384 bottom-right crop; returns a raw reference image,
        the processed pair, a guide map, the GT and the original size."""
        left_path = self.left[index]
        right_path = self.right[index]
        disp_path = self.disp_L[index]

        left_img = Image.open(left_path).convert('RGB')
        right_img = Image.open(right_path).convert('RGB')
        dataL = Image.open(disp_path)
        guideL = Image.open(self.guide[index])

        w, h = left_img.size
        box = (w - 1280, h - 384, w, h)

        left_img = left_img.crop(box)
        right_img = right_img.crop(box)
        w1, h1 = left_img.size

        # Disparity/guide PNGs are stored as uint16 scaled by 256.
        dataL = np.ascontiguousarray(dataL.crop(box), dtype=np.float32) / 256
        guideL = np.ascontiguousarray(guideL.crop(box), dtype=np.float32) / 256

        processed = preprocess.get_transform(augment=False)
        rawimage = preprocess.identity(256)

        reference = rawimage(left_img)
        left_img = processed(left_img)
        right_img = processed(right_img)

        return reference, left_img, right_img, guideL, dataL, h, w
Ejemplo n.º 10
0
    def __getitem__(self, index):
        """Read one HDF5 sample and attach per-sequence camera intrinsics."""
        sample_path = self.h5path[index]
        # e.g. .../2011_09_26_drive_0009_sync_image_02_0000000002.h5
        left_img, sparse_n, mask, gtdepth = self.loader(sample_path)

        # First 10 chars of the file name identify the recording date,
        # which keys the intrinsics table.
        date_key = sample_path.split('/')[-1][0:10]
        params_t = INSTICS[date_key]
        params = np.ones((256, 512, 3), dtype=np.float32)
        params[:, :, 0] *= params_t[0]
        params[:, :, 1] *= params_t[1]
        params[:, :, 2] *= params_t[2]
        params = np.reshape(params, [256, 512, 3]).astype(np.float32)

        # numpy(H, W, C) -> FloatTensor(C, H, W); gtdepth stays untouched.
        to_tensor = preprocess.get_transform(augment=False)
        left_img = to_tensor(left_img)
        sparse_n = to_tensor(sparse_n)
        mask = to_tensor(mask)

        return left_img, gtdepth, sparse_n, mask, params
Ejemplo n.º 11
0
    def __getitem__(self, index):
        """Stereo pair + disparity/256 + a randomly sparsified disparity."""
        left_img = self.loader(self.left[index])
        right_img = self.loader(self.right[index])
        dataL = self.dploader(self.disp_L[index])

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            # Only the bottom half of the frame has depth measurements.
            y1 = random.randint(80, h - th)
            box = (x1, y1, x1 + tw, y1 + th)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)

            dataL = np.ascontiguousarray(dataL, dtype=np.float32) / 256
            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            # Randomly erase 90-100% of the GT to simulate sparse input.
            erase_ratio = random.uniform(0.9, 1)
            return left_img, right_img, dataL, get_sparse_disp(
                dataL, erase_ratio=erase_ratio)
        else:
            w, h = left_img.size
            box = (w - 1248, h - 352, w, h)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)

            dataL = np.ascontiguousarray(dataL.crop(box), dtype=np.float32) / 256

            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            return left_img, right_img, dataL, get_sparse_disp(
                dataL, erase_ratio=0.95)
Ejemplo n.º 12
0
    def __getitem__(self, index):
        """Stereo pair + disparity tensor + calibration tensor."""
        left_img = self.loader(self.left[index])
        right_img = self.loader(self.right[index])
        calib = calib_loader(self.calib[index])
        dataL = self.dploader(self.disp_L[index])

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            box = (x1, y1, x1 + tw, y1 + th)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)
        else:
            # Fixed bottom-right 1200x352 evaluation crop.
            w, h = left_img.size
            box = (w - 1200, h - 352, w, h)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            dataL = dataL[h - 352:h, w - 1200:w]

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

        dataL = torch.from_numpy(dataL).float()
        calib = torch.tensor(calib).float()
        return left_img, right_img, dataL, calib
Ejemplo n.º 13
0
    def __getitem__(self, index):
        """Stereo pair + disparity + sparsified disparity (80% erased)."""
        left_img = self.loader(self.left[index])
        right_img = self.loader(self.right[index])
        dataL, scaleL = self.dploader(self.disp_L[index])

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            box = (x1, y1, x1 + tw, y1 + th)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)

            dataL = np.ascontiguousarray(dataL, dtype=np.float32)
            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            return left_img, right_img, dataL, get_sparse_disp(dataL,
                                                               erase_ratio=0.8)
        else:
            # Images: fixed bottom-right 960x544 crop.
            w, h = left_img.size
            left_img = left_img.crop((w - 960, h - 544, w, h))
            right_img = right_img.crop((w - 960, h - 544, w, h))
            to_tensor = preprocess.get_transform(augment=False)
            left_img = to_tensor(left_img)
            right_img = to_tensor(right_img)

            # GT: zero-pad (top/left) to at least 544x960, then take the
            # bottom-right 544x960 window to match the image crop.
            pad_rows = max(544 - h, 0)
            pad_cols = max(960 - w, 0)
            dataL = np.pad(dataL, ((pad_rows, 0), (pad_cols, 0)),
                           'constant', constant_values=0)
            dataL = dataL[max(dataL.shape[0] - 544, 0):dataL.shape[0],
                          max(dataL.shape[1] - 960, 0):dataL.shape[1]]

            return left_img, right_img, dataL, get_sparse_disp(dataL,
                                                               erase_ratio=0.8)
Ejemplo n.º 14
0
    def __getitem__(self, index):
        """Return (left, right, disparity) for sample *index*.

        Training: random 256x512 crop. Evaluation: fixed top-left 1024x1024
        crop. The original body mixed tabs and spaces, which is a TabError
        under Python 3; indentation is normalized here. The redundant early
        ``processed`` assignment (immediately shadowed in both branches) and
        the unused ``w, h`` read in the else branch were removed.
        """
        left = self.left[index]
        right = self.right[index]
        disp_L = self.disp_L[index]

        left_img = self.loader(left)
        right_img = self.loader(right)
        disp_img = self.loader_g(disp_L)

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
            right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
            disp_img = np.ascontiguousarray(disp_img, dtype=np.float32)
            disp_img = disp_img[y1:y1 + th, x1:x1 + tw]

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)

            return left_img, right_img, disp_img
        else:
            left_img = left_img.crop((0, 0, 1024, 1024))
            right_img = right_img.crop((0, 0, 1024, 1024))
            disp_img = np.ascontiguousarray(disp_img, dtype=np.float32)

            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)
            disp_img = disp_img[0:1024, 0:1024]
            return left_img, right_img, disp_img
    def __getitem__(self, index):
        """Sample with paths relative to ``self.datapath``.

        NOTE: in evaluation mode the disparity is returned uncropped while
        the images get a fixed bottom-right 960x544 crop (as in the original).
        """
        left_rel = self.left[index]
        right_rel = self.right[index]
        disp_rel = self.disp_L[index]

        left_img = self.loader(os.path.join(self.datapath, left_rel))
        right_img = self.loader(os.path.join(self.datapath, right_rel))
        dataL, _ = self.dploader(os.path.join(self.datapath, disp_rel))
        dataL = np.ascontiguousarray(dataL, dtype=np.float32)

        if self.training:
            w, h = left_img.size
            th, tw = 256, 512

            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            box = (x1, y1, x1 + tw, y1 + th)

            left_img = left_img.crop(box)
            right_img = right_img.crop(box)
            dataL = dataL[y1:y1 + th, x1:x1 + tw]

            to_tensor = preprocess.get_transform(augment=False)
            return to_tensor(left_img), to_tensor(right_img), dataL
        else:
            w, h = left_img.size
            left_img = left_img.crop((w - 960, h - 544, w, h))
            right_img = right_img.crop((w - 960, h - 544, w, h))

            to_tensor = preprocess.get_transform(augment=False)
            return to_tensor(left_img), to_tensor(right_img), dataL
Ejemplo n.º 16
0
    def __getitem__(self, index):
        """Return (image, low-res depth[, high-res depth]) for *index*.

        NOTE(review): if ``self.training`` is True while ``self.hr`` is None,
        ``hr_`` is never assigned and the training branch raises NameError —
        presumably training data always includes HR ground truth; confirm.
        """
        img = self.img[index]
        lr = self.lr[index]

        # High-resolution ground truth is optional.
        if self.hr is not None:
            hr = self.hr[index]
            hr_ = self.dploader(hr)

        img_ = self.loader(img)
        lr_ = self.dploader(lr)

        if self.training:
            w, h = img_.size
            th, tw = 256, 512

            # Random 256x512 crop applied to image, LR and HR maps alike.
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            img_ = img_.crop((x1, y1, x1 + tw, y1 + th))
            lr_ = np.ascontiguousarray(lr_, dtype=np.float32)

            lr_ = lr_[y1:y1 + th, x1:x1 + tw]
            hr_ = np.ascontiguousarray(hr_, dtype=np.float32)
            hr_ = hr_[y1:y1 + th, x1:x1 + tw]

            processed = preprocess.get_transform(augment=False)
            img_ = processed(img_)
            return img_, lr_, hr_
        elif self.hr is not None:
            # Evaluation with HR ground truth: fixed bottom-right 960x544
            # crop of the image only; the depth maps are returned uncropped.
            w, h = img_.size
            img_ = img_.crop((w - 960, h - 544, w, h))

            processed = preprocess.get_transform(augment=False)
            img_ = processed(img_)
            return img_, lr_, hr_
        else:
            # Evaluation without HR ground truth.
            lr_ = np.ascontiguousarray(lr_, dtype=np.float32)
            processed = preprocess.get_transform(augment=False)
            img_ = processed(img_)
            return img_, lr_
Ejemplo n.º 17
0
    def __getitem__(self, index):
        """Return a stereo pair and its disparity map."""
        left  = self.left[index]
        right = self.right[index]
        disp_L= self.disp_L[index]


        left_img = self.loader(left)
        right_img = self.loader(right)
        dataL, scaleL = self.dploader(disp_L)
        dataL = np.ascontiguousarray(dataL,dtype=np.float32)


        # training True: this sample is a training image; False: a test image
        # (the flag controls the crop size applied below).
        if self.training:  
           w, h = left_img.size
           th, tw = 256, 512
 
           x1 = random.randint(0, w - tw)  # randint(a, b): random integer n with a <= n <= b
           y1 = random.randint(0, h - th)

           left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))  # crop(box): cropped copy; box is (left, upper, right, lower)
           right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))

           dataL = dataL[y1:y1 + th, x1:x1 + tw]  # crop the label data to match

           processed = preprocess.get_transform(augment=False)  
           left_img   = processed(left_img)
           right_img  = processed(right_img)

           return left_img, right_img, dataL
        else:
           w, h = left_img.size
           left_img = left_img.crop((w-960, h-544, w, h))
           right_img = right_img.crop((w-960, h-544, w, h))
           processed = preprocess.get_transform(augment=False)  
           left_img       = processed(left_img)
           right_img      = processed(right_img)

           return left_img, right_img, dataL
Ejemplo n.º 18
0
    def __getitem__(self, index):
        """Top/down vertical-stereo pair (RGB + equi channels) and disparity."""
        up_path = self.up[index]
        down_path = self.down[index]
        disp_path = self.disp_name[index]
        equi_info = self.equi_infos

        up_img = self.loader(up_path)
        down_img = self.loader(down_path)
        disp = self.dploader(disp_path)

        # Append the equirectangular coordinate channels to both views.
        up_img = np.concatenate([np.array(up_img), equi_info], 2)
        down_img = np.concatenate([np.array(down_img), equi_info], 2)

        if self.training:
            h, w = up_img.shape[0], up_img.shape[1]
            # Tall crop (512 rows x 256 cols) — vertical stereo layout.
            th, tw = 512, 256

            # vertical remaining cropping
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)
            up_img = up_img[y1:y1 + th, x1:x1 + tw, :]
            down_img = down_img[y1:y1 + th, x1:x1 + tw, :]
            disp = np.ascontiguousarray(disp, dtype=np.float32)
            disp = disp[y1:y1 + th, x1:x1 + tw]
        else:
            disp = np.ascontiguousarray(disp, dtype=np.float32)

        # Both branches share the same preprocessing tail.
        to_tensor = preprocess.get_transform(augment=False)
        up_img = to_tensor(up_img)
        down_img = to_tensor(down_img)
        return up_img, down_img, disp
Ejemplo n.º 19
0
    def __getitem__(self, index):
        """Read a precomposed HDF5 sample and tensorize image + sparse map."""
        sample_path = self.h5path[index]
        left_img, sparse_n, mask, mask1, normal = self.loader(sample_path)

        # numpy(H, W, C) -> FloatTensor(C, H, W); masks and the normal map
        # are passed through untouched.
        to_tensor = preprocess.get_transform(augment=False)
        return to_tensor(left_img), to_tensor(sparse_n), mask, mask1, normal
Ejemplo n.º 20
0
 def get_loader(self, force_update=False):
     """(Re)build and return the DataLoader when the regime settings change.

     *force_update* forces a rebuild even if the regime did not change.
     """
     if force_update or self.regime.update(self.epoch, self.steps):
         setting = self.get_setting()
         self._transform = get_transform(**setting['transform'])
         # Only install the transform if the data config didn't specify one.
         setting['data'].setdefault('transform', self._transform)
         self._data = get_dataset(**setting['data'])
         if setting['other'].get('distributed', False):
             setting['loader']['sampler'] = DistributedSampler(self._data)
             setting['loader']['shuffle'] = None
             # pin-memory currently broken for distributed
             setting['loader']['pin_memory'] = False
         self._sampler = setting['loader'].get('sampler', None)
         self._loader = torch.utils.data.DataLoader(self._data,
                                                    **setting['loader'])
     return self._loader
Ejemplo n.º 21
0
    def __getitem__(self, index):
        """Return a preprocessed (left, right) pair in training mode.

        NOTE(review): when ``self.training`` is False this falls through and
        implicitly returns ``None`` — presumably the dataset is only used for
        training; confirm before using it for evaluation.
        """
        left = self.left[index]
        right = self.right[index]

        left_img = self.loader(left)
        right_img = self.loader(right)

        # The original computed np.fliplr copies of both images and an unused
        # (w, h) read here; none were ever used, so that dead per-item work
        # was removed.
        if self.training:
            processed = preprocess.get_transform(augment=False)
            left_img = processed(left_img)
            right_img = processed(right_img)
            return left_img, right_img
Ejemplo n.º 22
0
    def __getitem__(self, index):
        """Full-frame sample: decode disparity from its 16-bit offset coding."""
        left_path = self.left[index]
        right_path = self.right[index]
        disp_path = self.disp_L[index]

        left_img = self.loader(
            left_path
        )  # TODO: converting to grayscale may cause issues with first layer of the net. should verify.
        right_img = self.loader(right_path)
        raw_disp = self.dploader(disp_path)

        w, h = left_img.size

        # Stored as uint16 with a 2**15 offset; rescale by 2**8.
        dataL = (np.ascontiguousarray(raw_disp, dtype=np.float32) - 2**15) / 2**8

        to_tensor = preprocess.get_transform(augment=False)
        left_img = to_tensor(left_img)    # make this image a tensor
        right_img = to_tensor(right_img)  # make this image a tensor

        return left_img, right_img, dataL
Ejemplo n.º 23
0
    def __getitem__(self, index):
        """Random 256x512 crop over image, GT input, sparse map and masks."""
        img_path = self.left[index]
        normal_path = self.normal[index]
        gt_path = self.gts[index]
        left_img = self.loader(img_path)
        w, h = left_img.size
        input1, mask1 = self.inloader(gt_path)
        sparse, mask = self.sloader(normal_path)

        th, tw = 256, 512
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        rows = slice(y1, y1 + th)
        cols = slice(x1, x1 + tw)
        left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
        data_in1 = input1[rows, cols, :]
        sparse_n = sparse[rows, cols, :]
        mask = mask[rows, cols, :]
        mask1 = mask1[rows, cols, :]

        # Only image and sparse map are tensorized; masks stay as numpy.
        to_tensor = preprocess.get_transform(augment=False)
        left_img = to_tensor(left_img)
        sparse_n = to_tensor(sparse_n)
        return left_img, sparse_n, mask, mask1, data_in1
Ejemplo n.º 24
0
def main():
    """Entry point: parse arguments, build the model, optionally resume or
    evaluate a checkpoint, then run the train/validate loop, checkpointing
    and logging results each epoch."""
    global args, best_prec1, dtype
    best_prec1 = 0
    args = parser.parse_args()
    dtype = torch_dtypes.get(args.dtype)
    torch.manual_seed(args.seed)
    time_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    if args.evaluate:
        args.results_dir = '/tmp'
    # Was `args.save is ''`: identity comparison with a literal is
    # implementation-dependent (SyntaxWarning on Python >= 3.8).
    if args.save == '':
        args.save = time_stamp
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'),
                  resume=args.resume != '')
    results_path = os.path.join(save_path, 'results')
    results = ResultsLog(results_path,
                         title='Training Results - %s' % args.save)

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.device and torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
        torch.cuda.set_device(args.device_ids[0])
        cudnn.benchmark = True
    else:
        args.device_ids = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    # Merge user-supplied overrides into the default model config.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # Generator avoids materializing a throwaway list of element counts.
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(model, 'regime', [{
        'epoch': 0,
        'optimizer': args.optimizer,
        'lr': args.lr,
        'momentum': args.momentum,
        'weight_decay': args.weight_decay
    }])

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.to(args.device, dtype)
    model.to(args.device, dtype)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    optimizer = OptimRegime(model.parameters(), regime)
    logging.info('training regime: %s', regime)

    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     legend=['training', 'validation'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     legend=['training', 'validation'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     legend=['training', 'validation'],
                     title='Error@5',
                     ylabel='error %')
        results.save()
Ejemplo n.º 25
0
def main():
    """Entry point (hinge-loss variant): parse arguments, build the model,
    optionally resume/evaluate, then train and validate per epoch while
    adjusting the optimizer from the regime schedule."""
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    if args.evaluate:
        args.results_dir = 'tmp-hinge/'
        if not os.path.exists(args.results_dir):
            os.mkdir(args.results_dir)
    # Was `args.save is ''`: identity comparison with a literal is
    # implementation-dependent (SyntaxWarning on Python >= 3.8).
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    # Merge user-supplied overrides into the default model config.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # Generator avoids materializing a throwaway list of element counts.
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })
    # define loss function (criterion) and optimizer
    # CrossEntropyLoss() = log_softmax() + NLLLoss()
    #criterion = getattr(model, 'criterion', nn.NLLLoss)()
    criterion = getattr(model, 'criterion', HingeLoss)()
    #criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)
    # NOTE(review): an upstream version called search_binarized_modules(model)
    # here; the function is not found in this project, so it stays disabled.
    # search_binarized_modules(model)

    for epoch in range(args.start_epoch, args.epochs):
        # Re-derive optimizer hyper-parameters from the regime schedule.
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     title='Error@5',
                     ylabel='error %')
        results.save()
Ejemplo n.º 26
0
def main():
    """Distributed (Horovod) entry point: pick a gradient-compression (DGC)
    optimizer variant by `--pruning-mode`, build the model, then run the
    per-rank training loop, logging results from rank 0 only."""
    hvd.init()
    size = hvd.size()
    local_rank = hvd.local_rank()

    # Decorrelate RNG streams across ranks.
    torch.manual_seed(123 + hvd.rank())
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()

    # Import the DGC optimizer implementation matching the requested mode.
    # (Modes 2, 3, 6, 7, 8 -- chunck/topk/seperate/quant variants -- are
    # currently disabled upstream.)
    if args.pruning_mode == 1:
        print("thd mode")
        from hvd_utils.DGCoptimizer_thd import DGCDistributedOptimizer
    elif args.pruning_mode == 10:
        print("hybrid mode")
        from hvd_utils.DGCoptimizer_hybrid import DGCDistributedOptimizer
    elif args.pruning_mode == 11:
        print("hybrid quant mode")
        from hvd_utils.DGCoptimizer_hybrid_quant import DGCDistributedOptimizer
    elif args.pruning_mode == 12:
        print("hybrid v2 quant mode")
        from hvd_utils.DGCoptimizer_hybrid_quantv2 import DGCDistributedOptimizer
    elif args.pruning_mode == 13:
        print("hybrid v2 mode")
        from hvd_utils.DGCoptimizer_hybridv2 import DGCDistributedOptimizer
    else:
        print("pruning_mode should be set correctly")
        # Was bare `exit(0)`: the `exit` builtin comes from the site module
        # and is not guaranteed to exist; sys.exit is the portable form.
        sys.exit(0)
    from hvd_utils.DGCoptimizer_commoverlap import myhvdOptimizer

    if args.evaluate:
        args.results_dir = '/tmp'
    # Was `args.save is ''`: identity comparison with a literal is
    # implementation-dependent (SyntaxWarning on Python >= 3.8).
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, save=args.save) if False else os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        if hvd.rank() == 0:
            os.makedirs(save_path)
        else:
            # Give rank 0 a moment to create the shared directory.
            time.sleep(1)

    # Only rank 0 owns logging and the results files.
    if hvd.rank() == 0:
        setup_logging(os.path.join(save_path, 'log.txt'))
        results_file = os.path.join(save_path, 'results.%s')
        results = ResultsLog(results_file % 'csv', results_file % 'html')

    if hvd.rank() == 0:
        logging.info("saving to %s", save_path)
        logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        torch.cuda.manual_seed(123 + hvd.rank())
        args.gpus = [int(i) for i in args.gpus.split(',')]

        if args.use_cluster:
            torch.cuda.set_device(hvd.local_rank())
        else:
            # Map each local rank onto the user-supplied GPU list, falling
            # back to the first GPU when there are more ranks than GPUs.
            if (hvd.local_rank() < len(args.gpus)):
                print("rank, ", hvd.local_rank(), " is runing on ",
                      args.gpus[hvd.local_rank()])
                torch.cuda.set_device(args.gpus[hvd.local_rank()])
            else:
                print("rank, ", hvd.local_rank(), " is runing on ",
                      args.gpus[0])
                torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {
        'input_size': args.input_size,
        'dataset': args.dataset,
        'depth': args.resnet_depth
    }

    # Merge user-supplied overrides into the default model config.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # Generator avoids materializing a throwaway list of element counts.
    num_parameters = sum(p.nelement() for p in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model,
        'regime',
        {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr
                #'momentum': args.momentum,
                #'weight_decay': args.weight_decay
            }
        })
    adapted_regime = {}
    logging.info('self-defined momentum : %f, weight_decay : %f',
                 args.momentum, args.weight_decay)
    for e, v in regime.items():
        if args.lr_bb_fix and 'lr' in v:
            # Scale the learning rate with the square root of the global
            # batch size (large-batch LR heuristic).
            # v['lr'] *= (args.batch_size / args.mini_batch_size) ** 0.5
            v['lr'] *= (args.batch_size * hvd.size() / 128)**0.5
        adapted_regime[e] = v
    regime = adapted_regime

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    # NOTE(review): real data loaders are disabled; this script currently
    # runs with val_loader/train_loader set to None (benchmark mode?).
    val_loader = None

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_loader = None

    if hvd.rank() == 0:
        logging.info('training regime: %s', regime)
        print({
            i: list(w.size())
            for (i, w) in enumerate(list(model.parameters()))
        })
    # Snapshot of the initial weights, used for the step-distance metric.
    init_weights = [w.data.cpu().clone() for w in list(model.parameters())]

    U = []
    V = []
    print("current rank ", hvd.rank(), "local_rank ", hvd.local_rank(), \
            " USE_PRUNING ", args.use_pruning)
    print("model ", args.model, " use_nesterov ", args.use_nesterov)

    #TODO u, v will be cleared at the begining of each epoch
    if args.use_pruning:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
        if args.gpus is not None:
            optimizer = DGCDistributedOptimizer(
                optimizer,
                named_parameters=model.named_parameters(),
                use_gpu=True,
                momentum=0.9,
                weight_decay=1e-4)
        else:
            optimizer = DGCDistributedOptimizer(
                optimizer,
                named_parameters=model.named_parameters(),
                use_gpu=False,
                momentum=0.9,
                weight_decay=1e-4)
    else:
        if args.use_hvddist:
            print("use orignal hvd DistributedOptimizer")
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=0.9,
                                        weight_decay=1e-4,
                                        nesterov=True)
            #optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            optimizer = myhvdOptimizer(
                optimizer, named_parameters=model.named_parameters())
        else:
            # No pruning: DGC optimizer with allgather disabled.
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
            if args.gpus is not None:
                optimizer = DGCDistributedOptimizer(
                    optimizer,
                    named_parameters=model.named_parameters(),
                    use_gpu=True,
                    momentum=0.9,
                    weight_decay=1e-4,
                    use_allgather=False)
            else:
                optimizer = DGCDistributedOptimizer(
                    optimizer,
                    named_parameters=model.named_parameters(),
                    use_gpu=False,
                    momentum=0.9,
                    weight_decay=1e-4,
                    use_allgather=False)

    # Make all ranks start from rank 0's weights.
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    global_begin_time = time.time()
    for epoch in range(args.start_epoch, args.epochs // hvd.size()):
        #optimizer = adjust_optimizer(optimizer, epoch, regime)
        # Apply the scheduled learning rate for this (size-scaled) epoch.
        for e, v in regime.items():
            if epoch == e // hvd.size():
                for param_group in optimizer.param_groups:
                    param_group['lr'] = v['lr']
                break

        # train for one epoch
        train_result = train(train_loader, model, criterion, epoch, optimizer,
                             U, V)
        # NOTE(review): this sys.exit() aborts after the first training call,
        # making everything below unreachable -- presumably left in for
        # benchmarking a single epoch. Remove it to restore full training.
        sys.exit()

        train_loss, train_prec1, train_prec5, U, V = [
            train_result[r] for r in ['loss', 'prec1', 'prec5', 'U', 'V']
        ]

        # evaluate on validation set
        val_result = validate(val_loader, model, criterion, epoch)
        val_loss, val_prec1, val_prec5 = [
            val_result[r] for r in ['loss', 'prec1', 'prec5']
        ]

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        if hvd.rank() == 0:
            # torch 0.4.0 returns tensors here; convert for formatting.
            if torch.__version__ == "0.4.0":
                logging.info('\n Epoch: {0}\t'
                             'Training Loss {train_loss:.4f} \t'
                             'Training Prec@1 {train_prec1:.3f} \t'
                             'Training Prec@5 {train_prec5:.3f} \t'
                             'Validation Loss {val_loss:.4f} \t'
                             'Validation Prec@1 {val_prec1:.3f} \t'
                             'Validation Prec@5 {val_prec5:.3f} \n'.format(
                                 epoch + 1,
                                 train_loss=train_loss.cpu().numpy(),
                                 val_loss=val_loss.cpu().numpy(),
                                 train_prec1=train_prec1.cpu().numpy(),
                                 val_prec1=val_prec1.cpu().numpy(),
                                 train_prec5=train_prec5.cpu().numpy(),
                                 val_prec5=val_prec5.cpu().numpy()))
            else:
                logging.info('\n Epoch: {0}\t'
                             'Training Loss {train_loss:.4f} \t'
                             'Training Prec@1 {train_prec1:.3f} \t'
                             'Training Prec@5 {train_prec5:.3f} \t'
                             'Validation Loss {val_loss:.4f} \t'
                             'Validation Prec@1 {val_prec1:.3f} \t'
                             'Validation Prec@5 {val_prec5:.3f} \n'.format(
                                 epoch + 1,
                                 train_loss=train_loss,
                                 val_loss=val_loss,
                                 train_prec1=train_prec1,
                                 val_prec1=val_prec1,
                                 train_prec5=train_prec5,
                                 val_prec5=val_prec5))

        #Enable to measure more layers
        idxs = [0]  #,2,4,6,7,8,9,10]#[0, 12, 45, 63]

        # Distance each tracked parameter has moved from initialization.
        step_dist_epoch = {
            'step_dist_n%s' % k: (w.data.cpu() - init_weights[k]).norm()
            for (k, w) in enumerate(list(model.parameters())) if k in idxs
        }

        if hvd.rank() == 0:
            current_time = time.time()
            # Was a redundant nested `if hvd.rank() == 0` whose else branch
            # could never run; only the rank-0 path is kept.
            results.add(epoch=epoch + 1,
                        train_loss=train_loss.cpu().numpy(),
                        val_loss=val_loss.cpu().numpy(),
                        train_error1=100 - train_prec1.cpu().numpy(),
                        val_error1=100 - val_prec1.cpu().numpy(),
                        train_error5=100 - train_prec5.cpu().numpy(),
                        val_error5=100 - val_prec5.cpu().numpy(),
                        eslapse=current_time - global_begin_time)

            results.save()
Ejemplo n.º 27
0
def main():
    """Train and evaluate a model with the SGD-adjusting optimizer.

    Parses command-line args, builds the model and data loaders,
    optionally resumes from / evaluates a checkpoint, then runs the
    epoch loop: train, validate, checkpoint the best model, and log
    metrics to TensorBoard and a ResultsLog (csv/html).

    Uses module globals ``args`` and ``best_prec1``; returns ``None``.
    """
    torch.manual_seed(123)
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    if args.regime_bb_fix:
        # Stretch the epoch budget proportionally to the big-batch factor.
        args.epochs *= int(ceil(args.batch_size / 256.))

    if args.evaluate:
        args.results_dir = '/home/shai/tensorflow/generated_data/Pytorch/'
    # Fixed: 'is' on string literals relies on interning and is
    # implementation-defined; use equality instead.
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    writer = SummaryWriter(save_path)
    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.debug("run arguments: %s", args)

    if 'cuda' in args.type:
        torch.cuda.manual_seed(123)
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    # Fixed: identity comparison ('is not') on a string literal.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            # Checkpoint stores the *next* epoch number; step back one.
            args.start_epoch = checkpoint['epoch'] - 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # Generator expression avoids building an intermediate list.
    num_parameters = sum(l.nelement() for l in model.parameters())
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=True),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })
    # Rescale the regime's lr / epoch schedule for large-batch training.
    adapted_regime = {}
    max_lr = args.lr * (args.batch_size / 256.)
    for e, v in regime.items():
        if args.lr_bb_fix and 'lr' in v:
            # sqrt scaling rule for the learning rate.
            v['lr'] *= (args.batch_size / 256.)**0.5
        if args.regime_bb_fix:
            e *= ceil(args.batch_size / 256.)
        adapted_regime[e] = v
    regime = adapted_regime

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    optimizer = SGDAdjustOptimizer(optimizer,
                                   iters_per_adjust=args.iters_per_adjust,
                                   disable_lr_change=args.disable_lr_change,
                                   writer=writer,
                                   sqrt_factor=args.sqrt_factor,
                                   max_lr=max_lr)

    logging.info('training regime: %s', regime)
    # Snapshot of the initial weights, used below to measure how far each
    # tracked layer has moved per epoch.
    init_weights = [w.data.cpu().clone() for w in list(model.parameters())]

    for epoch in range(args.start_epoch, args.epochs):
        # Fixed: compare truthiness, not '== True'.
        if args.disable_lr_change:
            tmp_base_optimizer = adjust_optimizer(optimizer.base_optimizer,
                                                  epoch, regime)
            optimizer.set_base_optimizer(tmp_base_optimizer)

        # train for one epoch
        train_loss, train_prec1, train_prec5 = train(train_loader, model,
                                                     criterion, epoch,
                                                     optimizer)

        # evaluate on validation set
        val_loss, val_prec1, val_prec5 = validate(val_loader, model, criterion,
                                                  epoch)

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        writer.add_scalar('train_loss', train_loss, epoch)
        writer.add_scalar('val_loss', val_loss, epoch)
        writer.add_scalar('train_error1', 100 - train_prec1, epoch)
        writer.add_scalar('val_error1', 100 - val_prec1, epoch)
        writer.add_scalar('train_error5', 100 - train_prec5, epoch)
        writer.add_scalar('val_error5', 100 - val_prec5, epoch)

        # Enable to measure more layers
        idxs = [0]  #,2,4,6,7,8,9,10]#[0, 12, 45, 63]
        # L2 distance of each tracked layer's weights from their initial
        # values.
        step_dist_epoch = {
            'step_dist_n%s' % k: (w.data.cpu() - init_weights[k]).norm()
            for (k, w) in enumerate(list(model.parameters())) if k in idxs
        }
        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5,
                    **step_dist_epoch)
        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     title='Error@5',
                     ylabel='error %')
        for k in idxs:
            results.plot(x='epoch',
                         y=['step_dist_n%s' % k],
                         title='step distance per epoch %s' % k,
                         ylabel='val')

        results.save()
    writer.close()
Ejemplo n.º 28
0
        datay = datatype[1]
        for smpl in np.split(np.random.permutation(range(dataX.shape[0])), 10):
            ops = opfun(dataX[smpl])
            tgts = Variable(torch.from_numpy(datay[smpl]).long().squeeze())
            var = F.nll_loss(ops, tgts).data.numpy() / 10
            data_for_plotting[i, j] += accfun(ops, datay[smpl]) / 10.
        j += 1
    print(data_for_plotting[i])
    np.save('ShallowNetC3-intermediate-values', data_for_plotting)
    i += 1



# # Data loading code
# Default train/eval preprocessing for CIFAR-10; a model instance may
# override these via an 'input_transform' attribute (see getattr below).
default_transform = {
    'train': get_transform("cifar10",
                           input_size=None, augment=True),
    'eval': get_transform("cifar10",
                          input_size=None, augment=False)
}
# NOTE(review): 'model' is defined elsewhere in this script — verify it is
# constructed before this line runs.
transform = getattr(model, 'input_transform', default_transform)

# define loss function (criterion) and optimizer
# Fall back to cross-entropy when the model doesn't supply its own criterion.
criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
# Keep the criterion on CPU float tensors (CUDA variant commented out below).
criterion.type(torch.FloatTensor)
# #model.type(torch.cuda.FloatTensor)

i = 0
for batch_size in batch_range:
    mydict = {}
    batchmodel = torch.load("./models/ShallowNetCIFAR100BatchSize" + str(batch_size) + ".pth")
    for key, value in batchmodel.items():
Ejemplo n.º 29
0
def main():
    """Train and evaluate a model with rich progress-bar reporting.

    Parses command-line args, builds the model and data loaders,
    optionally resumes from / evaluates a checkpoint, runs the epoch
    loop (train, validate, checkpoint-best, log to ResultsLog), and
    finally plots the loss/precision curves via ``draw2``.

    Uses module globals ``args``, ``best_prec``, ``progress``, ``task2``
    and ``task3``; returns ``None``.
    """
    global args, best_prec
    global progress, task2, task3
    best_prec = 0
    args = parser.parse_args()

    if args.evaluate:
        args.results_dir = './tmp'
    # Fixed: 'is' on string literals relies on interning and is
    # implementation-defined; use equality instead.
    if args.save == '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.info("run arguments: %s", args)

    if 'cuda' in args.type:
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {'input_size': args.input_size, 'dataset': args.dataset}

    # Fixed: identity comparison ('is not') on a string literal.
    if args.model_config != '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'checkpoint.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch']
            best_prec = checkpoint['best_prec']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset, input_size=args.input_size, augment=False),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    regime = getattr(
        model, 'regime', {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'momentum': args.momentum,
                'weight_decay': args.weight_decay
            }
        })

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    logging.info('train dataset size: %d', len(train_data))

    val_data = get_dataset(args.dataset, 'eval', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    logging.info('validate dataset size: %d', len(val_data))

    # print net struct
    if args.dataset == 'mnist':
        summary(model, (1, 28, 28))
    elif args.dataset == 'cifar10':
        summary(model, (3, 32, 32))

    if args.evaluate:
        # Evaluation-only path: one validation pass, then exit.
        with Progress(
                "[progress.description]{task.description}{task.completed}/{task.total}",
                BarColumn(),
                "[progress.percentage]{task.percentage:>3.0f}%",
                TimeRemainingColumn(),
                auto_refresh=False) as progress:
            task3 = progress.add_task("[yellow]validating:",
                                      total=math.ceil(
                                          len(val_data) / args.batch_size))
            val_loss, val_prec1 = validate(val_loader, model, criterion, 0)
            logging.info('Evaluate {0}\t'
                         'Validation Loss {val_loss:.4f} \t'
                         'Validation Prec@1 {val_prec1:.3f} \t'.format(
                             args.evaluate,
                             val_loss=val_loss,
                             val_prec1=val_prec1))
        return

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)

    # restore results
    train_loss_list, train_prec_list = [], []
    val_loss_list, val_prec_list = [], []

    # print progressor
    with Progress(
            "[progress.description]{task.description}{task.completed}/{task.total}",
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.0f}%",
            TimeRemainingColumn(),
            auto_refresh=False) as progress:
        task1 = progress.add_task("[red]epoch:", total=args.epochs)
        task2 = progress.add_task("[blue]training:",
                                  total=math.ceil(
                                      len(train_data) / args.batch_size))
        task3 = progress.add_task("[yellow]validating:",
                                  total=math.ceil(
                                      len(val_data) / args.batch_size))

        # Fast-forward the epoch bar when resuming mid-run.
        for i in range(args.start_epoch):
            progress.update(task1, advance=1, refresh=True)

        begin = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            start = time.time()
            optimizer = adjust_optimizer(optimizer, epoch, regime)

            # train for one epoch
            train_loss, train_prec = train(train_loader, model, criterion,
                                           epoch, optimizer)
            train_loss_list.append(train_loss)
            train_prec_list.append(train_prec)

            # evaluate on validation set
            val_loss, val_prec = validate(val_loader, model, criterion, epoch)
            val_loss_list.append(val_loss)
            val_prec_list.append(val_prec)

            # remember best prec@1 and save checkpoint
            is_best = val_prec > best_prec
            best_prec = max(val_prec, best_prec)

            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'model': args.model,
                    'config': args.model_config,
                    'state_dict': model.state_dict(),
                    'best_prec': best_prec,
                    'regime': regime
                },
                is_best,
                path=save_path)
            logging.info(' Epoch: [{0}/{1}] Cost_Time: {2:.2f}s\n'
                         ' Training Loss {train_loss:.4f} \t'
                         'Training Prec {train_prec1:.3f} \t'
                         'Validation Loss {val_loss:.4f} \t'
                         'Validation Prec {val_prec1:.3f} \t'.format(
                             epoch + 1,
                             args.epochs,
                             time.time() - start,
                             train_loss=train_loss,
                             val_loss=val_loss,
                             train_prec1=train_prec,
                             val_prec1=val_prec))

            results.add(epoch=epoch + 1,
                        train_loss=train_loss,
                        val_loss=val_loss,
                        train_error1=100 - train_prec,
                        val_error1=100 - val_prec)
            results.save()

            # update progressor
            progress.update(task1, advance=1, refresh=True)

    logging.info(
        '----------------------------------------------------------------\n'
        'Whole Cost Time: {0:.2f}s      Best Validation Prec {1:.3f}\n'
        '-----------------------------------------------------------------'.
        format(time.time() - begin, best_prec))

    epochs = list(range(args.epochs))
    draw2(epochs, train_loss_list, val_loss_list, train_prec_list,
          val_prec_list)
Ejemplo n.º 30
0
def main():
    #torch.manual_seed(123)
    global args, best_prec1
    best_prec1 = 0
    args = parser.parse_args()
    args.epochs *= args.regime_bb_multi
    if args.regime_bb_fix:
        args.epochs *= (int)(ceil(args.batch_size * args.batch_multiplier /
                                  args.mini_batch_size))

    if args.evaluate:
        args.results_dir = '/tmp'
    if args.save is '':
        args.save = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
    save_path = os.path.join(args.results_dir, args.save)
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    else:
        raise OSError('Directory {%s} exists. Use a new one.' % save_path)

    setup_logging(os.path.join(save_path, 'log.txt'))
    results_file = os.path.join(save_path, 'results.%s')
    results = ResultsLog(results_file % 'csv', results_file % 'html')

    logging.info("saving to %s", save_path)
    logging.info("run arguments: %s", args)

    if 'cuda' in args.type:
        #torch.cuda.manual_seed_all(123)
        args.gpus = [int(i) for i in args.gpus.split(',')]
        torch.cuda.set_device(args.gpus[0])
        cudnn.benchmark = True
    else:
        args.gpus = None

    # create model
    logging.info("creating model %s", args.model)
    model = models.__dict__[args.model]
    model_config = {
        'input_size': args.input_size,
        'dataset': args.dataset,
        'noise': args.relu_noise
    }

    if args.model_config is not '':
        model_config = dict(model_config, **literal_eval(args.model_config))

    model = model(**model_config)
    logging.info("created model with configuration: %s", model_config)

    # optionally resume from a checkpoint
    if args.evaluate:
        if not os.path.isfile(args.evaluate):
            parser.error('invalid checkpoint: {}'.format(args.evaluate))
        checkpoint = torch.load(args.evaluate)
        model.load_state_dict(checkpoint['state_dict'])
        logging.info("loaded checkpoint '%s' (epoch %s)", args.evaluate,
                     checkpoint['epoch'])
    elif args.resume:
        checkpoint_file = args.resume
        if os.path.isdir(checkpoint_file):
            results.load(os.path.join(checkpoint_file, 'results.csv'))
            checkpoint_file = os.path.join(checkpoint_file,
                                           'model_best.pth.tar')
        if os.path.isfile(checkpoint_file):
            logging.info("loading checkpoint '%s'", args.resume)
            checkpoint = torch.load(checkpoint_file)
            args.start_epoch = checkpoint['epoch'] + 1
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            logging.info("loaded checkpoint '%s' (epoch %s)", checkpoint_file,
                         checkpoint['epoch'])
        else:
            logging.error("no checkpoint found at '%s'", args.resume)

    num_parameters = sum([l.nelement() for l in model.parameters()])
    logging.info("number of parameters: %d", num_parameters)

    # Data loading code
    default_transform = {
        'train':
        get_transform(args.dataset,
                      input_size=args.input_size,
                      augment=args.augment),
        'eval':
        get_transform(args.dataset, input_size=args.input_size, augment=False)
    }
    transform = getattr(model, 'input_transform', default_transform)
    if args.optimizer == 'Adam':
        assert (args.weight_decay is not None)
        regime = {
            0: {
                'optimizer': args.optimizer,
                'lr': args.lr,
                'weight_decay': args.weight_decay
            }
        }
    else:
        regime = getattr(
            model, 'regime', {
                0: {
                    'optimizer': args.optimizer,
                    'lr': args.lr,
                    'momentum': args.momentum,
                    'weight_decay': args.weight_decay
                }
            })
        if args.weight_decay:
            regime[0]['weight_decay'] = args.weight_decay
    adapted_regime = {}
    for e, v in regime.items():
        if args.lr_bb_fix and 'lr' in v:
            if args.lr_fix_policy == 'sqrt':
                v['lr'] *= (args.batch_size * args.batch_multiplier /
                            args.mini_batch_size)**0.5
            elif args.lr_fix_policy == 'linear':
                v['lr'] *= (args.batch_size * args.batch_multiplier /
                            args.mini_batch_size)
            else:
                raise ValueError('Unknown --lr_fix_policy')
        e *= args.regime_bb_multi
        if args.regime_bb_fix:
            e *= ceil(args.batch_size * args.batch_multiplier /
                      args.mini_batch_size)
        adapted_regime[e] = v
    regime = adapted_regime

    # define loss function (criterion) and optimizer
    criterion = getattr(model, 'criterion', nn.CrossEntropyLoss)()
    criterion.type(args.type)
    model.type(args.type)

    val_data = get_dataset(args.dataset, 'val', transform['eval'])
    val_loader = torch.utils.data.DataLoader(val_data,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    train_data = get_dataset(args.dataset, 'train', transform['train'])
    train_loader = torch.utils.data.DataLoader(train_data,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    logging.info('training regime: %s', regime)
    print(
        {i: list(w.size())
         for (i, w) in enumerate(list(model.parameters()))})
    init_weights = [w.data.cpu().clone() for w in list(model.parameters())]

    for epoch in range(args.start_epoch, args.epochs):
        optimizer = adjust_optimizer(optimizer, epoch, regime)

        # train for one epoch
        train_result = train(train_loader, model, criterion, epoch, optimizer)
        train_loss, train_prec1, train_prec5 = [
            train_result[r] for r in ['loss', 'prec1', 'prec5']
        ]

        # evaluate on validation set
        val_result = validate(val_loader, model, criterion, epoch)
        val_loss, val_prec1, val_prec5 = [
            val_result[r] for r in ['loss', 'prec1', 'prec5']
        ]

        # remember best prec@1 and save checkpoint
        is_best = val_prec1 > best_prec1
        best_prec1 = max(val_prec1, best_prec1)
        if is_best:
            logging.info('\n Epoch: {0}\t'
                         'Best Val Prec@1 {val_prec1:.3f} '
                         'with Val Prec@5 {val_prec5:.3f} \n'.format(
                             epoch + 1,
                             val_prec1=val_prec1,
                             val_prec5=val_prec5))
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'model': args.model,
                'config': args.model_config,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'regime': regime
            },
            is_best,
            path=save_path,
            save_all=args.save_all)
        logging.info('\n Epoch: {0}\t'
                     'Training Loss {train_loss:.4f} \t'
                     'Training Prec@1 {train_prec1:.3f} \t'
                     'Training Prec@5 {train_prec5:.3f} \t'
                     'Validation Loss {val_loss:.4f} \t'
                     'Validation Prec@1 {val_prec1:.3f} \t'
                     'Validation Prec@5 {val_prec5:.3f} \n'.format(
                         epoch + 1,
                         train_loss=train_loss,
                         val_loss=val_loss,
                         train_prec1=train_prec1,
                         val_prec1=val_prec1,
                         train_prec5=train_prec5,
                         val_prec5=val_prec5))

        #Enable to measure more layers
        idxs = [0]  #,2,4,6,7,8,9,10]#[0, 12, 45, 63]

        step_dist_epoch = {
            'step_dist_n%s' % k: (w.data.cpu() - init_weights[k]).norm()
            for (k, w) in enumerate(list(model.parameters())) if k in idxs
        }

        results.add(epoch=epoch + 1,
                    train_loss=train_loss,
                    val_loss=val_loss,
                    train_error1=100 - train_prec1,
                    val_error1=100 - val_prec1,
                    train_error5=100 - train_prec5,
                    val_error5=100 - val_prec5,
                    **step_dist_epoch)

        results.plot(x='epoch',
                     y=['train_loss', 'val_loss'],
                     title='Loss',
                     ylabel='loss')
        results.plot(x='epoch',
                     y=['train_error1', 'val_error1'],
                     title='Error@1',
                     ylabel='error %')
        results.plot(x='epoch',
                     y=['train_error5', 'val_error5'],
                     title='Error@5',
                     ylabel='error %')

        for k in idxs:
            results.plot(x='epoch',
                         y=['step_dist_n%s' % k],
                         title='step distance per epoch %s' % k,
                         ylabel='val')

        results.save()