def create_segmenter(encoder, decoder_config):
    with torch.no_grad():
        decoder = Decoder(
            inp_sizes=encoder.out_sizes,
            num_classes=NUM_CLASSES[args.dataset_type][0],
            config=decoder_config,
            agg_size=48,    # hard-coded; was args.agg_cell_size
            aux_cell=True,  # hard-coded; was args.aux_cell
            repeats=1)      # hard-coded; was args.sep_repeats
    # Fuse encoder and decoder
    segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
    logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
        compute_params(segmenter)))
    # The config is given, so there are no entropy / log_prob values to return.
    return segmenter
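# Hypothetical helper (not in the original file; names are illustrative): the
# fixed-config variant above expects a decoder genotype, e.g. one recorded in
# a genotypes.out file by the search script. Assuming each genotype is written
# as a Python literal after a "genotype: " prefix, it could be read back like:
import ast

def load_genotypes(path):
    configs = []
    with open(path) as f:
        for line in f:
            _, _, literal = line.partition("genotype: ")
            if literal:
                configs.append(ast.literal_eval(literal.strip()))
    return configs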
def create_segmenter(encoder):
    with torch.no_grad():
        decoder_config, entropy, log_prob = agent.controller.sample()
        decoder = Decoder(
            inp_sizes=encoder.out_sizes,
            num_classes=args.num_classes[0],
            config=decoder_config,
            agg_size=args.agg_cell_size,
            aux_cell=args.aux_cell,
            repeats=args.sep_repeats)
    # Fuse encoder and decoder
    segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
    logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
        compute_params(segmenter)))
    return segmenter, decoder_config, entropy, log_prob
def create_segmenter(encoder):
    if args.ctrl_version == "cvpr":
        from nn.micro_decoders import MicroDecoder as Decoder
    elif args.ctrl_version == "wacv":
        from nn.micro_decoders import TemplateDecoder as Decoder
    with torch.no_grad():
        decoder_config, entropy, log_prob = agent.controller.sample()
        decoder = Decoder(
            inp_sizes=encoder.out_sizes,
            num_classes=args.num_classes[0],
            config=decoder_config,
            agg_size=args.agg_cell_size,
            aux_cell=args.aux_cell,
            repeats=args.sep_repeats,
        )
    # Fuse encoder and decoder
    segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
    logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
        compute_params(segmenter)))
    return segmenter, decoder_config, entropy, log_prob
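# Minimal sketch (an assumption, not the repo's actual class) of the Segmenter
# fusion used by all three variants above: the encoder yields multi-scale
# feature maps whose channel sizes are exposed as encoder.out_sizes, and the
# decoder consumes them to produce segmentation logits.
import torch.nn as nn

class SegmenterSketch(nn.Module):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        feats = self.encoder(x)     # list of feature maps, one per out_size
        return self.decoder(feats)  # logits (plus an aux head when aux_cell=True)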
def main():
    # Set-up experiment
    args = get_arguments()
    logger = logging.getLogger(__name__)
    logger.debug(args)
    exp_name = time.strftime("%H_%M_%S")
    dir_name = "{}/{}".format(args.summary_dir, exp_name)
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    arch_writer = open("{}/genotypes.out".format(dir_name), "w")
    logger.info(" Running Experiment {}".format(exp_name))
    args.num_tasks = len(args.num_classes)
    segm_crit = nn.NLLLoss2d(ignore_index=255).cuda()  # nn.NLLLoss in newer PyTorch
    # Set-up random seeds
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Initialise encoder
    encoder = create_encoder(ctrl_version=args.ctrl_version)
    logger.info(" Loaded Encoder with #TOTAL PARAMS={:3.2f}M".format(
        compute_params(encoder)[0] / 1e6))
    # Generate teacher if any
    kd_net = None
    kd_crit = None
    if args.do_kd:
        from kd.rf_lw.model_lw_v2 import rf_lw152 as kd_model
        kd_crit = nn.MSELoss().cuda()
        kd_net = (
            kd_model(pretrained=True, num_classes=args.num_classes[0]).cuda().eval())
        logger.info(" Loaded teacher, #TOTAL PARAMS={:3.2f}M".format(
            compute_params(kd_net)[0] / 1e6))
    # Generate controller / RL-agent
    agent = create_agent(
        enc_num_layers=len(encoder.out_sizes),
        num_ops=args.num_ops,
        num_agg_ops=args.num_agg_ops,
        lstm_hidden_size=args.lstm_hidden_size,
        lstm_num_layers=args.lstm_num_layers,
        dec_num_cells=args.dec_num_cells,
        cell_num_layers=args.cell_num_layers,
        cell_max_repeat=args.cell_max_repeat,
        cell_max_stride=args.cell_max_stride,
        ctrl_lr=args.ctrl_lr,
        ctrl_baseline_decay=args.ctrl_baseline_decay,
        ctrl_agent=args.ctrl_agent,
        ctrl_version=args.ctrl_version,
    )
    logger.info(" Loaded Controller, #TOTAL PARAMS={:3.2f}M".format(
        compute_params(agent.controller)[0] / 1e6))

    def create_segmenter(encoder):
        if args.ctrl_version == "cvpr":
            from nn.micro_decoders import MicroDecoder as Decoder
        elif args.ctrl_version == "wacv":
            from nn.micro_decoders import TemplateDecoder as Decoder
        with torch.no_grad():
            decoder_config, entropy, log_prob = agent.controller.sample()
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=args.num_classes[0],
                config=decoder_config,
                agg_size=args.agg_cell_size,
                aux_cell=args.aux_cell,
                repeats=args.sep_repeats,
            )
        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter, decoder_config, entropy, log_prob

    # Sample first configuration
    segmenter, decoder_config, entropy, log_prob = create_segmenter(encoder)
    del encoder
    # Create dataloaders
    train_loader, val_loader, do_search = create_loaders(args)
    # Initialise task performance measurers
    task_ps = [[
        TaskPerformer(maxval=0.01, delta=0.9)
        for _ in range(args.num_segm_epochs[idx] // args.val_every[idx])
    ] for idx in range(args.num_tasks)]
    # Restore from previous checkpoint if any
    best_val, epoch_start = load_ckpt(args.ckpt_path, {"agent": agent})
    # Saver: keeps the checkpoint with the best validation score (a.k.a. best reward)
    saver = Saver(
        args=vars(args),
        ckpt_dir=args.snapshot_dir,
        best_val=best_val,
        condition=lambda x, y: x > y,
    )
    logger.info(" Pre-computing data for task0")
    Xy_train = populate_task0(segmenter, train_loader, kd_net, args.n_task0,
                              args.do_kd)
    if args.do_kd:
        del kd_net
    logger.info(" Training Process Starts")
    for epoch in range(epoch_start, args.num_epochs):
        reward = 0.0
        start = time.time()
        torch.cuda.empty_cache()
        logger.info(" Training Segmenter, Arch {}".format(str(epoch)))
        stop = False
        for task_idx in range(args.num_tasks):
            if stop:
                break
            torch.cuda.empty_cache()
            # Change dataloader
            train_loader.batch_sampler.batch_size = args.batch_size[task_idx]
            for loader in [train_loader, val_loader]:
                try:
                    loader.dataset.set_config(
                        crop_size=args.crop_size[task_idx],
                        shorter_side=args.shorter_side[task_idx],
                    )
                except AttributeError:
                    # for a Subset wrapper around the dataset
                    loader.dataset.dataset.set_config(
                        crop_size=args.crop_size[task_idx],
                        resize_side=args.resize_side[task_idx],
                    )
            logger.info(" Training Task {}".format(str(task_idx)))
            # Optimisers
            optim_enc, optim_dec = create_optimisers(
                args.enc_optim,
                args.dec_optim,
                args.enc_lr[task_idx],
                args.dec_lr[task_idx],
                args.enc_mom[task_idx],
                args.dec_mom[task_idx],
                args.enc_wd[task_idx],
                args.dec_wd[task_idx],
                segmenter.module.encoder.parameters(),
                segmenter.module.decoder.parameters(),
            )
            avg_param = init_polyak(
                args.do_polyak,
                segmenter.module.decoder if task_idx == 0 else segmenter)
            for epoch_segm in range(args.num_segm_epochs[task_idx]):
                if task_idx == 0:
                    train_task0(
                        Xy_train,
                        segmenter,
                        optim_dec,
                        epoch_segm,
                        segm_crit,
                        kd_crit,
                        args.batch_size[0],
                        args.freeze_bn[0],
                        args.do_kd,
                        args.kd_coeff,
                        args.dec_grad_clip,
                        args.do_polyak,
                        avg_param=avg_param,
                        polyak_decay=0.9,
                        aux_weight=args.dec_aux_weight,
                    )
                else:
                    train_segmenter(
                        segmenter,
                        train_loader,
                        optim_enc,
                        optim_dec,
                        epoch_segm,
                        segm_crit,
                        args.freeze_bn[1],
                        args.enc_grad_clip,
                        args.dec_grad_clip,
                        args.do_polyak,
                        args.print_every,
                        aux_weight=args.dec_aux_weight,
                        avg_param=avg_param,
                        polyak_decay=0.99,
                    )
                apply_polyak(
                    args.do_polyak,
                    segmenter.module.decoder if task_idx == 0 else segmenter,
                    avg_param,
                )
                if (epoch_segm + 1) % (args.val_every[task_idx]) == 0:
                    logger.info(
                        " Validating Segmenter, Arch {}, Task {}".format(
                            str(epoch), str(task_idx)))
                    task_miou = validate(
                        segmenter,
                        val_loader,
                        epoch,
                        epoch_segm,
                        num_classes=args.num_classes[task_idx],
                        print_every=args.print_every,
                        omit_classes=args.val_omit_classes,
                    )
                    # Decide whether to keep training this architecture.
                    c_task_ps = task_ps[task_idx][
                        (epoch_segm + 1) // args.val_every[task_idx] - 1]
                    if c_task_ps.step(task_miou):
                        continue
                    else:
                        logger.info(" Interrupting")
                        stop = True
                        break
            reward = task_miou
        if do_search:
            logger.info(" Training Controller")
            sample = (decoder_config, reward, entropy, log_prob)
            train_agent(agent, sample)
            # Log this epoch
            _, params = compute_params(segmenter)
            logger.info(" Decoder: {}".format(decoder_config))
            # Save controller params
            saver.save(reward, {"agent": agent.state_dict(), "epoch": epoch},
                       logger)
            # Save genotypes
            epoch_time = (time.time() - start) / sum(
                args.num_segm_epochs[:(task_idx + 1)])
            arch_writer.write(
                "reward: {:.4f}, epoch: {}, params: {}, epoch_time: {:.4f}, genotype: {}\n"
                .format(reward, epoch, params, epoch_time, decoder_config))
            arch_writer.flush()
            # Sample a new architecture
            del segmenter
            encoder = create_encoder(ctrl_version=args.ctrl_version)
            segmenter, decoder_config, entropy, log_prob = create_segmenter(
                encoder)
            del encoder
def main():
    # Set-up experiment
    args = get_arguments()
    logger = logging.getLogger(__name__)
    exp_name = time.strftime('%H_%M_%S')
    logger.info(" Running Experiment {}".format(exp_name))
    args.num_tasks = len(NUM_CLASSES[args.dataset_type])
    segm_crit = nn.NLLLoss2d(ignore_index=255).cuda()  # nn.NLLLoss in newer PyTorch
    # Set-up random seeds
    torch.manual_seed(args.random_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.random_seed)
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    # Create dataloaders
    train_loader, val_loader, do_search = create_loaders(args)

    def create_segmenter(encoder, decoder_config):
        with torch.no_grad():
            decoder = Decoder(
                inp_sizes=encoder.out_sizes,
                num_classes=NUM_CLASSES[args.dataset_type][0],
                config=decoder_config,
                agg_size=48,    # hard-coded; was args.agg_cell_size
                aux_cell=True,  # hard-coded; was args.aux_cell
                repeats=1)      # hard-coded; was args.sep_repeats
        # Fuse encoder and decoder
        segmenter = nn.DataParallel(Segmenter(encoder, decoder)).cuda()
        logger.info(" Created Segmenter, #PARAMS (Total, No AUX)={}".format(
            compute_params(segmenter)))
        return segmenter

    # decoder_config_arry: list of genotypes to retrain, defined elsewhere
    for decoder_config in decoder_config_arry:
        # Initialise encoder
        encoder = create_encoder()
        logger.info(" Loaded Encoder with #TOTAL PARAMS={:3.2f}M".format(
            compute_params(encoder)[0] / 1e6))
        # Build the segmenter for this decoder config
        segmenter = create_segmenter(encoder, decoder_config)
        del encoder
        logger.info(" Loaded Segmenter with #TOTAL PARAMS={:3.2f}M".format(
            compute_params(segmenter)[0] / 1e6))
        # Saver: keeps the checkpoint with the best validation score (a.k.a. best reward)
        now = datetime.datetime.now()
        snapshot_dir = (args.snapshot_dir + '_train_' + args.dataset_type +
                        "_{:%Y%m%dT%H%M}".format(now))
        seg_saver = seg_Saver(ckpt_dir=snapshot_dir)
        arch_writer = open('{}/genotypes.out'.format(snapshot_dir), 'w')
        arch_writer.write('genotype: {}\n'.format(decoder_config))
        arch_writer.flush()
        logger.info(" Pre-computing data for task0")
        kd_net = None  # knowledge distillation is stubbed out in this variant
        logger.info(" Training Process Starts")
        for task_idx in range(args.num_tasks):  # tasks 0 and 1
            if task_idx == 0:
                continue
            torch.cuda.empty_cache()
            # Change dataloader
            train_loader.batch_sampler.batch_size = \
                BATCH_SIZE[args.dataset_type][task_idx]
            logger.info(" Training Task {}".format(str(task_idx)))
            # Optimisers
            optim_enc, optim_dec = create_optimisers(
                args.optim_enc,
                args.optim_dec,
                args.lr_enc[task_idx],
                args.lr_dec[task_idx],
                args.mom_enc[task_idx],
                args.mom_dec[task_idx],
                args.wd_enc[task_idx],
                args.wd_dec[task_idx],
                segmenter.module.encoder.parameters(),
                segmenter.module.decoder.parameters())
            kd_crit = None  # knowledge distillation is stubbed out in this variant
            # e.g. [5, 1] or [20, 8] epochs per task, depending on the dataset
            for epoch_segm in range(TRAIN_EPOCH_NUM[args.dataset_type][task_idx]):
                # Train the segmenter end-to-end once per epoch
                final_loss = train_segmenter(
                    segmenter,
                    train_loader,
                    optim_enc,
                    optim_dec,
                    epoch_segm,
                    segm_crit,
                    args.freeze_bn[1],
                    args.enc_grad_clip,
                    args.dec_grad_clip,
                    args.do_polyak,
                    args.print_every,
                    aux_weight=args.dec_aux_weight,
                    # Polyak averaging (avg_param) is omitted in this variant
                    polyak_decay=0.99)
        seg_saver.save(final_loss, segmenter.state_dict(), logger)

        # Validate
        segmenter.eval()
        data_file = dataset_dirs[args.dataset_type]['VAL_LIST']
        data_dir = dataset_dirs[args.dataset_type]['VAL_DIR']
        with open(data_file, 'rb') as f:
            datalist = f.readlines()
        try:
            datalist = [
                (k, v) for k, v, _ in map(
                    lambda x: x.decode('utf-8').strip('\n').split('\t'),
                    datalist)
            ]
        except ValueError:
            # Ad hoc fallback for test lists without mask paths.
            datalist = [
                (k, k)
                for k in map(lambda x: x.decode('utf-8').strip('\n'), datalist)
            ]
        imgs_all = [
            os.path.join(data_dir, datalist[i][0]) for i in range(len(datalist))
        ]
        msks_all = [
            os.path.join(data_dir, datalist[i][1]) for i in range(len(datalist))
        ]
        validate_output_dir = os.path.join(
            dataset_dirs[args.dataset_type]['VAL_DIR'], 'validate_output')
        validate_gt_dir = os.path.join(
            dataset_dirs[args.dataset_type]['VAL_DIR'], 'validate_gt')
        # Recreate the output directories from scratch
        for out_dir in (validate_output_dir, validate_gt_dir):
            if os.path.exists(out_dir):
                shutil.rmtree(out_dir)
            os.makedirs(out_dir)
        for i, img_path in enumerate(imgs_all):
            img = np.array(Image.open(img_path))
            msk = np.array(Image.open(msks_all[i]))
            orig_size = img.shape[:2][::-1]
            img_inp = torch.tensor(
                prepare_img(img).transpose(2, 0, 1)[None]).float().to(device)
            # Per-class score maps, e.g. 47x63x21 before resizing
            segm = segmenter(img_inp)[0].squeeze().data.cpu().numpy().transpose(
                (1, 2, 0))
            # Resize scores back to the original resolution, e.g. 375x500x21
            segm = cv2.resize(segm, orig_size, interpolation=cv2.INTER_CUBIC)
            segm = segm.argmax(axis=2).astype(np.uint8)
            image_name = img_path.split('/')[-1].split('.')[0]
            cv2.imwrite(
                os.path.join(validate_output_dir, "{}.png".format(image_name)),
                segm)
            cv2.imwrite(
                os.path.join(validate_gt_dir, "{}.png".format(image_name)),
                msk)
        if args.dataset_type == 'celebA':
            cal_f1_score_celebA(validate_gt_dir, validate_output_dir,
                                arch_writer)
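# Hypothetical helper (not in the original file): compute mean IoU over the
# prediction / ground-truth PNG pairs written by the loop above, for datasets
# other than celebA (which already gets an F1 score). Assumes predictions and
# masks share filenames and resolutions, with label 255 marking ignored pixels.
import os
import numpy as np
from PIL import Image

def mean_iou_from_dirs(gt_dir, pred_dir, num_classes, ignore_index=255):
    hist = np.zeros((num_classes, num_classes), dtype=np.int64)
    for name in sorted(os.listdir(pred_dir)):
        pred = np.array(Image.open(os.path.join(pred_dir, name)))
        gt = np.array(Image.open(os.path.join(gt_dir, name)))
        valid = gt != ignore_index
        # Joint histogram of (gt, pred) label pairs over valid pixels.
        hist += np.bincount(
            num_classes * gt[valid].astype(int) + pred[valid].astype(int),
            minlength=num_classes ** 2).reshape(num_classes, num_classes)
    inter = np.diag(hist)
    union = hist.sum(axis=0) + hist.sum(axis=1) - inter
    iou = inter / np.maximum(union, 1)
    return float(np.nanmean(iou))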