def print_without_img(sample_list):
    """Print a readable summary of every sample, leaving out pixel data.

    For each sample dict this prints, between two separator lines: the image
    height, width and id; the decoded ``obbox``/``bbox`` numbers (via
    ``u.get_nums_from_bbox``) or a note saying which of them is missing; the
    label and lesion flag; and for ``image`` and ``original`` whether they are
    present, plus their ``shape`` when they are. Every line is flushed
    immediately.
    """
    def emit(*fields):
        print(*fields, flush=True)

    for sample in sample_list:
        emit('--~~~~~~~~~~~~~~~--')
        for key in ('image_height', 'image_width', 'image_id'):
            emit(key, sample[key])

        obbox = sample['obbox']
        bbox = sample['bbox']
        if obbox is None and bbox is None:
            emit('Obbox and Bbox is None')
        elif obbox is None:
            emit('Obbox is None')
        elif bbox is None:
            emit('BBox is None')
        # Whichever boxes exist are decoded, obbox before bbox.
        if obbox is not None:
            emit('obbox', u.get_nums_from_bbox(obbox))
        if bbox is not None:
            emit('bbox', u.get_nums_from_bbox(bbox))

        emit('label', sample['label'])
        emit('has_lesion', sample['has_lesion'])

        for key in ('image', 'original'):
            array = sample[key]
            if array is None:
                emit(key + ' is None')
            else:
                emit(key + ' is NOT None')
                emit(key + ' shape', array.shape)
        emit('-------------------')
def __call__(self, sample_list):
    """Expand a sample into many samples that differ only in their bbox.

    Lesion samples (``has_lesion == 1``): for each enlargement factor f in a
    fixed list (4.15 down to 1.04) a box of ``int(f * rn) x int(f * cn)`` is
    built whose top-left corner is drawn at random from up to ``(f - 1)``
    box lengths above/left of the lesion box's corner; every box is clipped
    with ``check_bbox``. The untouched original bbox is appended last, and
    ``cat_list`` gives one category (0 for the largest boxes ... 4 for the
    smallest and the original) per box.

    Other samples: six random square boxes whose side is 0.8, 0.7, 0.6, 0.5,
    0.2 and 0.1 times the smaller image side, all in category 0.

    Returns:
        ``multiply_sample(sample, bbox_list, cat_list)``.

    NOTE(review): only the last sample of ``sample_list`` reaches
    ``multiply_sample``; presumably the list always holds a single sample —
    confirm against the callers.
    """
    for sample in sample_list:
        height, width = sample['image_height'], sample['image_width']

        if sample['has_lesion'] == 1:
            r, c, rn, cn = u.get_nums_from_bbox(sample['bbox'])
            cat_list = [0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
            bbox_list = []
            for factor in (4.15, 2.92, 2.59,
                           2.33, 2.13, 1.97, 1.83,
                           1.71, 1.6, 1.51, 1.42,
                           1.35, 1.28, 1.21, 1.15,
                           1.09, 1.04):
                spare = factor - 1
                row_lim = int(r - spare * rn)
                col_lim = int(c - spare * cn)
                row_lim = 1 if row_lim < 1 else row_lim
                col_lim = 1 if col_lim < 1 else col_lim
                # Keep randint's exclusive upper bound above the clamped
                # lower bound when the lesion touches the top/left border.
                r = 2 if r <= 1 else r
                c = 2 if c <= 1 else c
                try:
                    new_r = np.random.randint(row_lim, r)
                    new_c = np.random.randint(col_lim, c)
                except ValueError:
                    raise ValueError(
                        'r {}, c {}, rn {}, cn {}, row_lim {}, col_lim {}'.format(
                            r, c, rn, cn, row_lim, col_lim))
                new_r, new_c, new_rn, new_cn = check_bbox(
                    new_r, new_c, int(factor * rn), int(factor * cn),
                    height, width)
                bbox_list.append(u.create_bbox(new_r, new_c, new_rn, new_cn))
            bbox_list.append(sample['bbox'])
        else:
            side_fractions = [0.8, 0.7, 0.6, 0.5, 0.2, 0.1]
            cat_list = [0] * len(side_fractions)
            bbox_list = []
            for fraction in side_fractions:
                rn = cn = int(min(height, width) * fraction)
                r = np.random.randint(1, height - rn)
                c = np.random.randint(1, width - cn)
                # Shrink the square if it would run past the bottom/right edge.
                if r + rn > height:
                    rn = cn = height - r
                if c + cn > width:
                    rn = cn = width - c
                r, c, rn, cn = check_bbox(r, c, rn, cn, height, width)
                bbox_list.append(u.create_bbox(r, c, rn, cn))
    return multiply_sample(sample, bbox_list, cat_list)
def test_triple(no_imgs=10):
    """Dump the samples of the first ``no_imgs`` batches to PNG files.

    Uses ``get_dataloaders`` with batch size 1 and no data augmentation.
    Every sample of every processed batch is written with
    ``save_image_with_bb`` to ``../checkpoints/test_steps_04/triple_<n>.png``,
    where ``n`` counts samples across all batches. Samples with a truthy bbox
    are drawn with their bbox and obbox numbers; the others with
    ``bbox_flag=False`` and zeroed coordinates.

    Args:
        no_imgs: number of batches to process.
    """
    import itertools

    checkpoint_dir = os.path.abspath('../checkpoints/test_steps_04')
    os.makedirs(checkpoint_dir, exist_ok=True)
    # NOTE(review): test_data_augmentations unpacks the same call as
    # ``train_loader, _``; confirm which loader is intended here.
    _, train_loader = get_dataloaders(checkpoint_dir, rsyncing=False,
                                      selective_sampling=False,
                                      warmup_trainer=None, batch_size=1,
                                      num_workers=0,
                                      data_aug_vec=[0, 0, 0, 0])
    cnt = 0
    # islice yields exactly ``no_imgs`` batches; checking ``i >= no_imgs``
    # only after processing batch ``i`` handled one batch too many.
    for batch in itertools.islice(train_loader, no_imgs):
        # A single loop covers both one-sample and multi-sample batches.
        for j in range(len(batch['image_batch'])):
            if batch['bbox_batch'][j]:
                r, c, rn, cn = u.get_nums_from_bbox(batch['bbox_batch'][j])
                ro, co, rno, cno = u.get_nums_from_bbox(
                    batch['obbox_batch'][j])
                bbox_flag = True
            else:
                r = c = rn = cn = ro = co = rno = cno = 0
                bbox_flag = False
            save_image_with_bb(j, batch,
                               os.path.join(checkpoint_dir,
                                            'triple_{}.png'.format(cnt)),
                               bbox_flag=bbox_flag, r=r, c=c, rn=rn, cn=cn,
                               ro=ro, co=co, rno=rno, cno=cno)
            cnt += 1
def __call__(self, sample_list):
    """Crop each sample's image to a box and resize it to ``self.targetsize``.

    Lesion samples (``has_lesion == 1``) are cropped to their bbox. Other
    samples are cropped to a random square whose side is drawn from
    ``[3, min(image_height, image_width) - 3)``. Either box is first passed
    through ``check_bbox``.

    Args:
        sample_list: list of sample dicts (keys read here: 'has_lesion',
            'bbox', 'obbox', 'image', 'image_height', 'image_width').

    Returns:
        list: one sample per input, rebuilt by ``recover_sample`` with the
        resized crop as its 'image'. If a resulting image is None, diagnostic
        information is printed and the sample is still returned.
    """
    out = []
    for sample in sample_list:
        height = sample['image_height']
        width = sample['image_width']
        if sample['has_lesion'] == 1:
            r, c, rn, cn = u.get_nums_from_bbox(sample['bbox'])
        else:
            # The side must fit both dimensions. Drawing it from the width
            # alone made randint(1, height - rn) raise ValueError on images
            # wider than tall.
            rn = cn = np.random.randint(3, min(height, width) - 3)
            # randint's upper bound is exclusive, so r + rn < height and
            # c + cn < width hold by construction.
            r = np.random.randint(1, height - rn)
            c = np.random.randint(1, width - cn)
        r, c, rn, cn = check_bbox(r, c, rn, cn, height, width)

        # Only the shape is needed for the diagnostic below; avoid copying
        # every image.
        original_shape = np.asarray(sample['image']).shape
        resized = im.resize(sample['image'], r, c, rn, cn, self.targetsize, 0)
        recovered = recover_sample(sample, 'image', resized)
        out.append(recovered)
        if recovered['image'] is None:
            print('r,c,rn,cn, targetsize', r, c, rn, cn, self.targetsize)
            print('bbox', sample['bbox'])
            print('obbox', sample['obbox'])
            print('im_copy', original_shape)
            if resized is None:
                print('resized is None')
            else:
                print('resized shape', resized.shape)
    return out
def test_data_augmentations(no_imgs=10):
    """Dump the first sample of the first ``no_imgs`` augmented batches.

    Uses ``get_dataloaders`` with batch size 10 and
    ``data_aug_vec=[0.5, 0.25, 0.5, 0.5]``. For each processed batch,
    'Random <i>' is printed and the first sample is written with
    ``save_image_with_bb`` to ``../checkpoints/test_steps_04/rnd_<i>.png``,
    with its bbox/obbox numbers when the bbox is truthy.

    Args:
        no_imgs: number of batches to process.
    """
    import itertools

    checkpoint_dir = os.path.abspath('../checkpoints/test_steps_04')
    os.makedirs(checkpoint_dir, exist_ok=True)
    train_loader, _ = get_dataloaders(checkpoint_dir, rsyncing=False,
                                      selective_sampling=False,
                                      warmup_trainer=None, batch_size=10,
                                      num_workers=0,
                                      data_aug_vec=[0.5, 0.25, 0.5, 0.5])
    # islice yields exactly ``no_imgs`` batches; checking ``i >= no_imgs``
    # only after processing batch ``i`` handled one batch too many.
    for i, batch in enumerate(itertools.islice(train_loader, no_imgs)):
        print('Random {}'.format(i))
        if batch['bbox_batch'][0]:
            r, c, rn, cn = u.get_nums_from_bbox(batch['bbox_batch'][0])
            ro, co, rno, cno = u.get_nums_from_bbox(batch['obbox_batch'][0])
            bbox_flag = True
        else:
            r = c = rn = cn = ro = co = rno = cno = 0
            bbox_flag = False
        # NOTE(review): test_triple calls save_image_with_bb(j, batch, ...)
        # with a leading sample index; one of the two call sites presumably
        # uses a stale signature — confirm.
        save_image_with_bb(batch,
                           os.path.join(checkpoint_dir,
                                        'rnd_{}.png'.format(i)),
                           bbox_flag=bbox_flag, r=r, c=c, rn=rn, cn=cn,
                           ro=ro, co=co, rno=rno, cno=cno)
def __call__(self, sample_list):
    """Randomly translate each sample's bbox by a few pixels.

    For samples with a lesion (or every sample when ``self.cat`` is set), the
    row and the column are each, independently with probability
    ``self.prob``, shifted by an integer drawn uniformly from
    ``[-self.pixel_range, self.pixel_range]``. The shifted box is passed
    through ``check_bbox``. All other samples get ``bbox = None``.

    Args:
        sample_list: list of sample dicts (keys read here: 'has_lesion',
            'bbox', 'obbox', 'image_height', 'image_width').

    Returns:
        list: the samples rebuilt by ``recover_sample`` with the new bbox.

    Raises:
        ValueError: if a rebuilt sample's 'image' has no ``shape`` (e.g. it is
            None). The message carries the old/new box numbers.
    """
    def nums_or_none(bbox):
        # Decode only real boxes; None means "no bbox" in these samples.
        return None if bbox is None else u.get_nums_from_bbox(bbox)

    out = []
    for sample in sample_list:
        # Reset per sample so the diagnostic below never hits unbound names
        # (samples without a box) or reports values from a previous sample.
        r = c = rn = cn = None
        new_r = new_c = new_rn = new_cn = None
        if self.cat or sample['has_lesion'] == 1:
            r, c, rn, cn = u.get_nums_from_bbox(sample['bbox'])
            if np.random.rand() < self.prob:
                shift_r = np.random.randint(-self.pixel_range,
                                            self.pixel_range + 1)
            else:
                shift_r = 0
            if np.random.rand() < self.prob:
                shift_c = np.random.randint(-self.pixel_range,
                                            self.pixel_range + 1)
            else:
                shift_c = 0
            new_r, new_c, new_rn, new_cn = check_bbox(
                r + shift_r, c + shift_c, rn, cn,
                sample['image_height'], sample['image_width'])
            new_bbox = u.create_bbox(new_r, new_c, new_rn, new_cn)
        else:
            new_bbox = None
        out.append(recover_sample(sample, 'bbox', new_bbox))
        if not hasattr(out[-1]['image'], 'shape'):
            raise ValueError(
                'Transform old {} {} {} {} new {} {} {} {}, bbox {}, obbox {}'.format(
                    r, c, rn, cn, new_r, new_c, new_rn, new_cn,
                    nums_or_none(out[-1]['bbox']),
                    nums_or_none(sample['obbox'])))
    return out
def save_orig_with_bbs(checkpoint_dir, img_id, big_org_img, bboxes):
    """Save an image with boxes drawn on it as ``<checkpoint_dir>/<img_id>_orig.jpg``.

    Args:
        checkpoint_dir: directory the JPEG is written to.
        img_id: identifier used in the file name.
        big_org_img: channels-first image; axes are swapped (0<->1, then
            1<->2) so that e.g. (C, H, W) becomes (H, W, C) before plotting.
        bboxes: iterable of bboxes. Falsy entries are skipped; the others are
            decoded with ``get_nums_from_bbox`` and drawn as thin green
            rectangles.

    The axes fill the whole figure, which is ``width / height`` by 1 inch and
    saved with ``dpi=height``, so the output has roughly the image's own pixel
    size. The figure is closed even if drawing or saving fails.
    """
    big_org_img = np.swapaxes(big_org_img, 0, 1)
    big_org_img = np.swapaxes(big_org_img, 1, 2)
    sizes = np.shape(big_org_img)
    height = float(sizes[0])
    width = float(sizes[1])

    fig = plt.figure()
    try:
        fig.set_size_inches(width / height, 1, forward=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(big_org_img)
        for bb in bboxes:
            if not bb:
                continue
            r, c, rn, cn = get_nums_from_bbox(bb)
            # Rectangle is anchored at (x, y) = (column, row).
            ax.add_patch(patches.Rectangle((c, r), cn, rn, linewidth=0.25,
                                           edgecolor='g', facecolor='none'))
        fig.savefig(os.path.join(checkpoint_dir,
                                 '{}_orig.jpg'.format(img_id)),
                    dpi=height)
    finally:
        # Always release the figure; otherwise a failure leaks it.
        plt.close(fig)
def __call__(self, sample_list):
    """Randomly rescale each sample's bbox.

    For samples with a lesion (or every sample when ``self.cat`` is set), with
    probability ``self.prob`` the box is scaled by a factor drawn uniformly
    from ``[1 - self.max_percentage, 1 + self.max_percentage]``. The corner or
    edge that stays fixed is chosen uniformly among five options: grow/shrink
    toward the upper-left, left, up, lower-right, or around the centre. The
    result is passed through ``check_bbox``. Otherwise the bbox is kept.
    Samples without a lesion (and ``self.cat`` unset) get ``bbox = None``.

    Args:
        sample_list: list of sample dicts (keys read here: 'has_lesion',
            'bbox', 'image_height', 'image_width').

    Returns:
        list: the samples rebuilt by ``recover_sample`` with the new bbox.

    Raises:
        ValueError: if a rebuilt sample's 'image' has no ``shape`` (e.g. it is
            None). The message carries the old/new box numbers.
    """
    def nums_or_none(bbox):
        # Decode only real boxes; None means "no bbox" in these samples.
        return None if bbox is None else u.get_nums_from_bbox(bbox)

    out = []
    for sample in sample_list:
        # Reset per sample so the diagnostic below never hits unbound names
        # (unscaled samples) or reports values from a previous sample.
        r = c = rn = cn = None
        new_r = new_c = new_rn = new_cn = None
        if sample['has_lesion'] == 1 or self.cat:
            if np.random.rand() < self.prob:
                r, c, rn, cn = u.get_nums_from_bbox(sample['bbox'])
                factor = np.random.uniform(1 - self.max_percentage,
                                           1 + self.max_percentage)
                new_rn = int(rn * factor)
                new_cn = int(cn * factor)
                diff_r = new_rn - rn
                diff_c = new_cn - cn
                # Get direction of scaling
                rnd_num = np.random.rand()
                if rnd_num < 0.2:
                    # left up: bottom-right corner stays fixed
                    new_r = r - diff_r
                    new_c = c - diff_c
                elif rnd_num < 0.4:
                    # left: top and right edges stay fixed
                    new_r = r
                    new_c = c - diff_c
                elif rnd_num < 0.6:
                    # up: bottom and left edges stay fixed
                    new_r = r - diff_r
                    new_c = c
                elif rnd_num < 0.8:
                    # right down: top-left corner stays fixed
                    new_r = r
                    new_c = c
                else:
                    # center
                    new_r = int(r - diff_r // 2)
                    new_c = int(c - diff_c // 2)
                new_r, new_c, new_rn, new_cn = check_bbox(
                    new_r, new_c, new_rn, new_cn,
                    sample['image_height'], sample['image_width'])
                new_bbox = u.create_bbox(new_r, new_c, new_rn, new_cn)
            else:
                new_bbox = sample['bbox']
        else:
            new_bbox = None
        out.append(recover_sample(sample, 'bbox', new_bbox))
        if not hasattr(out[-1]['image'], 'shape'):
            raise ValueError(
                'Scale old {} {} {} {} new {} {} {} {}, bbox {}, obbox {}, img h {}, img w {}'.format(
                    r, c, rn, cn, new_r, new_c, new_rn, new_cn,
                    nums_or_none(out[-1]['bbox']),
                    nums_or_none(out[-1]['obbox']),
                    out[-1]['image_height'], out[-1]['image_width']))
    return out
def __call__(self, sample_list):
    """Turn a sample into itself plus four enlarged context boxes.

    For lesion samples (``has_lesion == 1``) four boxes are drawn, in order
    "very very small", "very small", "small" and "big". They are 5/4, 3/2, 2
    and 3 times the lesion box size. Each one's top-left corner is drawn at
    random from up to 1/4, 1/2, 1 and 2 box lengths above/left of the lesion
    box's corner. Every box is clipped with ``check_bbox``. For other samples
    all four boxes are None.

    Returns:
        ``quintuple_sample(sample, very_very_small, very_small, small, big)``.

    NOTE(review): only the last sample of ``sample_list`` reaches
    ``quintuple_sample``; presumably the list always holds a single sample —
    confirm against the callers.
    """
    # (reach, scale) per box, smallest first. Fractional reaches/scales are
    # truncated with int(); integer ones are used as plain products.
    growth_steps = ((1 / 4, 5 / 4), (1 / 2, 3 / 2), (1, 2), (2, 3))
    for sample in sample_list:
        if sample['has_lesion'] != 1:
            new_boxes = [None] * len(growth_steps)
            continue
        r, c, rn, cn = u.get_nums_from_bbox(sample['bbox'])
        new_boxes = []
        for reach, scale in growth_steps:
            fractional = isinstance(reach, float)
            row_lim = r - reach * rn
            col_lim = c - reach * cn
            if fractional:
                row_lim, col_lim = int(row_lim), int(col_lim)
            row_lim = 1 if row_lim < 1 else row_lim
            col_lim = 1 if col_lim < 1 else col_lim
            # Keep randint's exclusive upper bound above the clamped lower
            # bound when the lesion touches the top/left border.
            r = 2 if r <= 1 else r
            c = 2 if c <= 1 else c
            try:
                new_r = np.random.randint(row_lim, r)
                new_c = np.random.randint(col_lim, c)
            except ValueError:
                raise ValueError(
                    'r {}, c {}, rn {}, cn {}, row_lim {}, col_lim {}'.format(
                        r, c, rn, cn, row_lim, col_lim))
            if fractional:
                new_rn, new_cn = int(scale * rn), int(scale * cn)
            else:
                new_rn, new_cn = scale * rn, scale * cn
            new_r, new_c, new_rn, new_cn = check_bbox(
                new_r, new_c, new_rn, new_cn,
                sample['image_height'], sample['image_width'])
            new_boxes.append(u.create_bbox(new_r, new_c, new_rn, new_cn))
    return quintuple_sample(sample, *new_boxes)
def render(self, mode='human', close=False, with_axis=True, with_state=False):
    """Render the original image with the ground-truth and current boxes.

    The ground-truth box (``self.get_original_bb()``) is drawn in green and
    the current box (``self.get_current_bb()``) in red over
    ``self.get_original_img()``, using the 'binary' colormap.

    Supported modes:

    - ``'human'``: draws the figure, adding a second panel with
      ``self.get_current_state()`` when ``with_state`` is set. It saves the
      figure to a temporary PNG and returns the image read back with
      ``imageio.imread``. The figure is closed and the temporary file removed
      before returning.
    - ``'figure'``: returns the matplotlib figure (``with_state`` is ignored);
      the caller is responsible for closing it.

    Args:
        mode (str): ``'human'`` or ``'figure'``.
        close (bool): unused; kept for signature compatibility.
        with_axis (bool): when False, hide the axes of the image panel.
        with_state (bool): ``'human'`` only; add the current-state panel.

    Returns:
        The image read by ``imageio.imread`` (``'human'``) or the figure
        (``'figure'``).

    Raises:
        NotImplementedError: for any other mode.
    """
    import tempfile

    if mode not in ('human', 'figure'):
        raise NotImplementedError

    cm = 'binary'
    lwidth = 2
    image = self.get_original_img()
    r, c, rn, cn = get_nums_from_bbox(self.get_current_bb())
    ro, co, rno, cno = get_nums_from_bbox(self.get_original_bb())

    def box(x, y, w, h, color):
        # Rectangle is anchored at (x, y) = (column, row).
        return patches.Rectangle((x, y), w, h, linewidth=lwidth,
                                 edgecolor=color, facecolor='none')

    if mode == 'figure':
        fig, ax = plt.subplots(1, 1)
        if not with_axis:
            ax.set_axis_off()
        ax.imshow(get_img_imshow(image), cmap=cm)
        ax.add_patch(box(co, ro, cno, rno, 'g'))
        ax.add_patch(box(c, r, cn, rn, 'r'))
        return fig

    # mode == 'human'
    fig, axes = plt.subplots(1, 2 if with_state else 1)
    try:
        image_ax = axes[0] if with_state else axes
        if not with_axis:
            image_ax.set_axis_off()
        image_ax.imshow(get_img_imshow(image), cmap=cm)
        image_ax.add_patch(box(co, ro, cno, rno, 'g'))
        image_ax.add_patch(box(c, r, cn, rn, 'r'))
        if with_state:
            axes[1].set_axis_off()
            axes[1].imshow(get_img_imshow(self.get_current_state()), cmap=cm)
        # Use a unique temporary file rather than a fixed 'dummy.png' in the
        # working directory. That way concurrent renders cannot clobber each
        # other, and no user file gets overwritten and deleted.
        fd, filename = tempfile.mkstemp(suffix='.png')
        os.close(fd)
        try:
            fig.savefig(filename, bbox_inches='tight')
            img = imageio.imread(filename)
        finally:
            os.remove(filename)
    finally:
        # Release the figure; leaving it open leaks one figure per call.
        plt.close(fig)
    return img
# Sub-modules of the ResNet-based Q-models that form the feature extractor.
_RESNET_BACKBONE_LAYERS = ('conv1', 'bn1', 'relu', 'maxpool',
                           'layer1', 'layer2', 'layer3', 'layer4')
# Sub-modules of the recurrent Q-head.
_RECURRENT_HEAD_LAYERS = ('ll1', 'll2', 'll3', 'relu2', 'lstm')


def _q_param_groups(model, feat_model_string, combi, recurrent, toy, lr,
                    recurrent_groups=True):
    """Return the optimizer parameter spec for a Q-model.

    Combined models (``combi``) get per-module groups: feature-extractor
    modules use ``lr / 10``, head modules the optimizer's base lr. With
    ``recurrent_groups=False``, recurrent combined models fall back to
    ``model.parameters()``. Non-combined models always get
    ``model.parameters()``.
    """
    def slow(names):
        return [{'params': getattr(model, name).parameters(), 'lr': lr / 10}
                for name in names]

    def fast(names):
        return [{'params': getattr(model, name).parameters()}
                for name in names]

    if combi and recurrent <= 0:
        if feat_model_string == 'auto' or feat_model_string == 'resnet':
            return slow(_RESNET_BACKBONE_LAYERS) + fast(('qnet',))
        return slow(('features',)) + fast(('qnet',))
    if combi and recurrent > 0 and recurrent_groups:
        if toy:
            return slow(('features',)) + fast(_RECURRENT_HEAD_LAYERS)
        return slow(_RESNET_BACKBONE_LAYERS) + fast(_RECURRENT_HEAD_LAYERS)
    return model.parameters()


def _save_trajectories(checkpoint_dir, trajectory_all_imgs, triggered_all_imgs,
                       Q_s_all_imgs, all_imgs, all_gt, actions_all_imgs,
                       lwidth):
    """Plot evaluation trajectories and collect detection step indices.

    ``trajectory_all_imgs[i][j]`` is the list of (r, c, rn, cn) boxes visited
    from start j of image i. For every trajectory with
    ``triggered_all_imgs[i][j] == 1`` the index of its last step is
    collected. Images with index < 10 get one PNG per step under
    ``<checkpoint_dir>/trajectories_test``, named ``<i>_<j>_<k>.png``. The
    final step of a triggered trajectory is named ``<i>_<j>_<k>_trigger.png``
    instead. A ground truth of None is drawn as an all-zero box.

    Returns:
        list: the collected step indices, in visiting order.
    """
    if len(trajectory_all_imgs) > 1:
        image_indices = range(len(trajectory_all_imgs))
        # Multi-image runs pass eps=-1; single-image runs do not pass eps.
        extra_kwargs = {'eps': -1}
    elif len(trajectory_all_imgs) == 1:
        image_indices = [0]
        extra_kwargs = {}
    else:
        image_indices = []
        extra_kwargs = {}

    steps_until_detection = []
    for i in image_indices:
        orig_img = all_imgs[i]
        if all_gt[i] is None:
            orig_r = orig_c = orig_rn = orig_cn = 0
        else:
            orig_r, orig_c, orig_rn, orig_cn = get_nums_from_bbox(all_gt[i])
        for j, starts in enumerate(trajectory_all_imgs[i]):
            for k, (r, c, rn, cn) in enumerate(starts):
                is_trigger = (triggered_all_imgs[i][j] == 1
                              and k == len(starts) - 1)
                if is_trigger:
                    steps_until_detection.append(k)
                if i >= 10:
                    continue
                save_image_with_orig_plus_current_bb(
                    orig_img,
                    '{}/trajectories_test/{}_{}_{}{}.png'.format(
                        checkpoint_dir, i, j, k,
                        '_trigger' if is_trigger else ''),
                    bbox_flag=True, r=r, c=c, rn=rn, cn=cn,
                    ro=orig_r, co=orig_c, rno=orig_rn, cno=orig_cn,
                    lwidth=lwidth, Q_s=Q_s_all_imgs[i][j][k],
                    action=actions_all_imgs[i][j][k], **extra_kwargs)
    return steps_until_detection


def test(cfg_dict, feat_model_string, rsyncing, toy=False):
    """Evaluate a warm-up Q-network checkpoint on the validation environment.

    Builds the feature extractor and Q-network described by ``cfg_dict`` and
    loads ``warmup_model_<experiment_name>.pth.tar`` from the Q-net
    checkpoint directory. It then runs one evaluation pass that records
    trajectories and writes the following under
    ``<checkpoint_dir>/trajectories_test``:

    - one PNG per trajectory step for the first 10 images,
    - ``steps_until_detection.pkl`` with the last-step index of every
      triggered trajectory,
    - ``val_metrics_arr.pkl`` with the metrics returned by ``evaluate``.

    If the checkpoint file does not exist, a message is printed and nothing
    is evaluated.

    Args:
        cfg_dict: experiment configuration consumed by the ``get_*`` helpers.
        feat_model_string: feature-extractor name. ``'auto'`` and ``'resnet'``
            are wrapped in ``res.ResNetFeatures``, anything else in
            ``m.NetNoDecisionLayer``.
        rsyncing: forwarded to ``get_val_env_only``.
        toy: forwarded to the model/environment builders.

    Raises:
        ValueError: if the configured optimizer is not one of 'optim.Adam',
            'optim.SGD', 'optim.RMSprop'.
    """
    checkpoint_dir, _, experiment_name = get_q_save_check_tensorB_expName(
        cfg_dict)
    opti_feat, lr, mom, _ = get_f_train_opti_lr_mom_epochs(cfg_dict)
    inputsize, hiddensize, outputsize = get_q_net_input_hidden_output(cfg_dict)
    (_, _, _, _, combi, _, recurrent, hidden_rec, _, _,
     _) = get_q_variants_oneImg_maxNumImgsT_maxNumImgsV_double_combi_paramNoise_recurrent_recSize_distReward_distFactor_hist(
        cfg_dict)
    cat = get_f_variants_selectiveS_checkPretrained_cat(cfg_dict)[2]
    featnet_checkpoint = get_f_save_check_tensorB_expName(cfg_dict)[0]

    print('-------------')
    print('feat_model_string', feat_model_string)
    print('-------------')
    feature_model = get_feature_model(feat_model_string, feat_model_string,
                                      load_pretrained=True, opti='optim.Adam',
                                      lr=lr, mom=mom,
                                      checkpoint_pretrained=featnet_checkpoint,
                                      cat=cat)
    if feat_model_string == 'auto' or feat_model_string == 'resnet':
        feature_model = res.ResNetFeatures(feature_model)
    else:
        feature_model = m.NetNoDecisionLayer(feature_model)
    if torch.cuda.is_available():
        feature_model.cuda()

    model = get_q_model(combi, recurrent, toy, inputsize, hiddensize,
                        outputsize, feature_model=feature_model,
                        hidden_rec=hidden_rec)
    if torch.cuda.is_available():
        model.cuda()
    criterion = nn.MSELoss()
    optimizer = get_optimizer(
        _q_param_groups(model, feat_model_string, combi, recurrent, toy, lr),
        opti_feat, lr, mom)
    print(model, flush=True)

    checkpoint_filename = os.path.join(
        checkpoint_dir, 'warmup_model_{}.pth.tar'.format(experiment_name))
    print('Load checkpoint from {}'.format(
        os.path.abspath(checkpoint_filename)))
    if not os.path.exists(checkpoint_filename):
        print('For testing, checkpoint filename has to exist')
        return
    os.makedirs('{}/trajectories_test'.format(checkpoint_dir), exist_ok=True)

    # The optimizer returned by load_checkpoint is discarded and a fresh one
    # is built below. Loaded optimizer state might leave the LR too low.
    model, _, initial_epoch = load_checkpoint(model, optimizer,
                                              filename=checkpoint_filename)
    # NOTE(review): unlike the optimizer built above, recurrent combined
    # models get plain model.parameters() here; kept as-is — confirm intent.
    model_params = _q_param_groups(model, feat_model_string, combi,
                                   recurrent, toy, lr,
                                   recurrent_groups=False)
    if opti_feat not in ('optim.Adam', 'optim.SGD', 'optim.RMSprop'):
        raise ValueError('Unsupported optimizer {}'.format(opti_feat))
    print('Using optimizer {}'.format(opti_feat))
    # eval only ever sees one of the whitelisted names checked above.
    optimizer_cls = eval(opti_feat)
    if opti_feat == 'optim.Adam':
        optimizer = optimizer_cls(model_params, lr=lr)
    else:
        optimizer = optimizer_cls(model_params, lr=lr, momentum=mom)

    model_path = 'model_best_{}.pth.tar'.format(experiment_name)
    print('Loading model checkpointed at epoch {}'.format(initial_epoch))
    print('Get val env', flush=True)
    val_env = get_val_env_only(cfg_dict, feature_model, rsyncing=rsyncing,
                               toy=toy, f_one=True)
    warmup_trainer = QNetTrainer(cfg_dict, model, val_env,
                                 experiment_name=experiment_name,
                                 log_dir='default',
                                 checkpoint_dir=checkpoint_dir,
                                 checkpoint_filename=model_path,
                                 for_testing=True, tau_schedule=0,
                                 recurrent=recurrent)
    warmup_trainer.compile(loss=criterion, optimizer=optimizer)
    print('Evaluate', flush=True)
    (val_metrics_arr, trajectory_all_imgs, triggered_all_imgs, Q_s_all_imgs,
     all_imgs, all_gt, actions_all_imgs) = warmup_trainer.evaluate(
        val_env, 0, save_trajectory=True)
    print('Val_metrics_arr', val_metrics_arr)

    print('Save', flush=True)
    steps_until_detection = _save_trajectories(
        checkpoint_dir, trajectory_all_imgs, triggered_all_imgs, Q_s_all_imgs,
        all_imgs, all_gt, actions_all_imgs, lwidth=2)
    # Context managers make sure the pickle files are flushed and closed.
    with open('{}/trajectories_test/steps_until_detection.pkl'.format(
            checkpoint_dir), 'wb') as f:
        pkl.dump(steps_until_detection, f)
    with open('{}/trajectories_test/val_metrics_arr.pkl'.format(
            checkpoint_dir), 'wb') as f:
        pkl.dump(val_metrics_arr, f)