# dataset_root = os.path.join('../datasets/data', opts['test_db']) dataset_root = os.path.join('../datasets/data', 'vot15') vid_folders = [] for filename in os.listdir(dataset_root): if os.path.isdir(os.path.join(dataset_root, filename)): vid_folders.append(filename) vid_folders.sort(key=str.lower) # all_precisions = [] save_root = args.save_result_images save_root_npy = args.save_result_npy for vid_folder in vid_folders: print('Loading {}...'.format(args.weight_file)) opts['num_videos'] = 1 net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=True, vid_index=args.vid_index) net.train() if args.cuda: net = nn.DataParallel(net) cudnn.benchmark = True if args.cuda: net = net.cuda() if args.save_result_images is not None: args.save_result_images = os.path.join(save_root, vid_folder) if not os.path.exists(args.save_result_images): os.mkdir(args.save_result_images) args.save_result_npy = os.path.join(save_root_npy, vid_folder)
def adnet_train_sl(args, opts):
    """Supervised-learning (SL) stage of ADNet training.

    Builds per-video positive/negative patch datasets, then alternates
    positive and negative minibatches while cycling through video domains
    (one ADNetDomainSpecific module per training video when
    ``args.multidomain``).  A checkpoint is written at every epoch boundary,
    every 5000 iterations, and once at the end.

    Args:
        args: parsed command-line namespace; fields read here include
            cuda, resume, save_folder, save_file, visualize, multidomain,
            num_workers, start_epoch, start_iter,
            send_images_to_visualization.
        opts: option dict; ``opts['num_videos']`` is overwritten from the
            training-video list, and ``opts['minibatch_size']``,
            ``opts['numEpoch']`` and ``opts['train']`` sub-options are read.

    Returns:
        tuple: ``(net, domain_specific_nets, train_videos)``.
    """
    # Select the default tensor type once, before any tensor is created.
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        if not args.cuda:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    if args.visualize:
        # writer is only defined when visualization is on; every later use
        # is guarded by the same args.visualize flag.
        writer = SummaryWriter(log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts, trained_file=args.resume, multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # Lower learning rate (1e-4) for the pretrained backbone, 1e-3 for the
    # fresh fully-connected heads.  DataParallel wrapping changes the
    # attribute path (net.module.*), hence the duplicated branches.
    if args.cuda:
        optimizer = optim.SGD([
            {'params': net.module.base_network.parameters(), 'lr': 1e-4},
            {'params': net.module.fc4_5.parameters()},
            {'params': net.module.fc6.parameters()},
            {'params': net.module.fc7.parameters()}],  # as action dynamic is zero, it doesn't matter
            lr=1e-3, momentum=opts['train']['momentum'], weight_decay=opts['train']['weightDecay'])
    else:
        optimizer = optim.SGD([
            {'params': net.base_network.parameters(), 'lr': 1e-4},
            {'params': net.fc4_5.parameters()},
            {'params': net.fc6.parameters()},
            {'params': net.fc7.parameters()}],
            lr=1e-3, momentum=opts['train']['momentum'], weight_decay=opts['train']['weightDecay'])

    if args.resume:
        # net.load_weights(args.resume)
        # Network weights were already loaded by adnet(trained_file=...);
        # here only the optimizer state is restored from the checkpoint.
        checkpoint = torch.load(args.resume)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    net.train()

    if not args.resume:
        # Gaussian init scaled by 0.01 for the FC heads; biases 0.1 for
        # fc4/fc5 and 0 for fc6/fc7.
        print('Initializing weights...')
        if args.cuda:
            scal = torch.Tensor([0.01])
            # fc 4
            nn.init.normal_(net.module.fc4_5[0].weight.data)
            net.module.fc4_5[0].weight.data = net.module.fc4_5[0].weight.data * scal.expand_as(net.module.fc4_5[0].weight.data)
            net.module.fc4_5[0].bias.data.fill_(0.1)
            # fc 5
            nn.init.normal_(net.module.fc4_5[3].weight.data)
            net.module.fc4_5[3].weight.data = net.module.fc4_5[3].weight.data * scal.expand_as(net.module.fc4_5[3].weight.data)
            net.module.fc4_5[3].bias.data.fill_(0.1)
            # fc 6
            nn.init.normal_(net.module.fc6.weight.data)
            net.module.fc6.weight.data = net.module.fc6.weight.data * scal.expand_as(net.module.fc6.weight.data)
            net.module.fc6.bias.data.fill_(0)
            # fc 7
            nn.init.normal_(net.module.fc7.weight.data)
            net.module.fc7.weight.data = net.module.fc7.weight.data * scal.expand_as(net.module.fc7.weight.data)
            net.module.fc7.bias.data.fill_(0)
        else:
            scal = torch.Tensor([0.01])
            # fc 4
            nn.init.normal_(net.fc4_5[0].weight.data)
            net.fc4_5[0].weight.data = net.fc4_5[0].weight.data * scal.expand_as(net.fc4_5[0].weight.data)
            net.fc4_5[0].bias.data.fill_(0.1)
            # fc 5
            nn.init.normal_(net.fc4_5[3].weight.data)
            net.fc4_5[3].weight.data = net.fc4_5[3].weight.data * scal.expand_as(net.fc4_5[3].weight.data)
            net.fc4_5[3].bias.data.fill_(0.1)
            # fc 6
            nn.init.normal_(net.fc6.weight.data)
            net.fc6.weight.data = net.fc6.weight.data * scal.expand_as(net.fc6.weight.data)
            net.fc6.bias.data.fill_(0)
            # fc 7
            nn.init.normal_(net.fc7.weight.data)
            net.fc7.weight.data = net.fc7.weight.data * scal.expand_as(net.fc7.weight.data)
            net.fc7.bias.data.fill_(0)

    action_criterion = nn.CrossEntropyLoss()
    # NOTE(review): other SL variants in this file use nn.BCELoss for the
    # confidence score; here CrossEntropyLoss is applied to score_out with
    # integer labels (score_label.long() below) — confirm this is intended.
    score_criterion = nn.CrossEntropyLoss()

    print('generating Supervised Learning dataset..')
    # dataset = SLDataset(train_videos, opts, transform=
    datasets_pos, datasets_neg = initialize_pos_neg_dataset(train_videos, opts, transform=ADNet_Augmentation(opts))

    number_domain = opts['num_videos']

    batch_iterators_pos = []
    batch_iterators_neg = []

    # calculating number of data
    len_dataset_pos = 0
    len_dataset_neg = 0
    for dataset_pos in datasets_pos:
        len_dataset_pos += len(dataset_pos)
    for dataset_neg in datasets_neg:
        len_dataset_neg += len(dataset_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations

    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    data_loaders_pos = []
    data_loaders_neg = []

    for dataset_pos in datasets_pos:
        data_loaders_pos.append(data.DataLoader(dataset_pos, opts['minibatch_size'], num_workers=args.num_workers, shuffle=True, pin_memory=True))
    for dataset_neg in datasets_neg:
        data_loaders_neg.append(data.DataLoader(dataset_neg, opts['minibatch_size'], num_workers=args.num_workers, shuffle=True, pin_memory=True))

    epoch = args.start_epoch
    if epoch != 0 and args.start_iter == 0:
        # resuming by epoch number only: derive the iteration counter.
        start_iter = epoch * epoch_size
    else:
        start_iter = args.start_iter

    # Schedule of pos/neg minibatches for one epoch: epoch_size_pos ones
    # followed by epoch_size_neg zeros, shuffled each epoch.
    which_dataset = list(np.full(epoch_size_pos, fill_value=1))
    which_dataset.extend(np.zeros(epoch_size_neg, dtype=int))
    shuffle(which_dataset)

    # Random order in which the video domains are visited.
    which_domain = np.random.permutation(number_domain)

    action_loss = 0
    score_loss = 0

    # training loop
    for iteration in range(start_iter, max_iter):
        if args.multidomain:
            curr_domain = which_domain[iteration % len(which_domain)]
        else:
            curr_domain = 0

        # if new epoch (not including the very first iteration)
        if (iteration != start_iter) and (iteration % epoch_size == 0):
            epoch += 1
            shuffle(which_dataset)
            np.random.shuffle(which_domain)

            print('Saving state, epoch:', epoch)

            domain_specific_nets_state_dict = []
            for domain_specific_net in domain_specific_nets:
                domain_specific_nets_state_dict.append(domain_specific_net.state_dict())

            # NOTE(review): domain_specific_nets_state_dict is built above
            # but the checkpoint stores the module objects themselves
            # ('adnet_domain_specific_state_dict': domain_specific_nets).
            # Looks unintended — confirm what the checkpoint loader expects
            # before changing it.
            torch.save({
                'epoch': epoch,
                'adnet_state_dict': net.state_dict(),
                'adnet_domain_specific_state_dict': domain_specific_nets,
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(args.save_folder, args.save_file) + 'epoch' + repr(epoch) + '.pth')

            if args.visualize:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss / epoch_size, 'score_loss': score_loss / epoch_size, 'total': (action_loss + score_loss) / epoch_size}, global_step=epoch)

            # reset epoch loss counters
            action_loss = 0
            score_loss = 0

        # if new epoch (including the first iteration), initialize the batch iterator
        # or just resuming where batch_iterator_pos and neg haven't been initialized
        if iteration % epoch_size == 0 or len(batch_iterators_pos) == 0 or len(batch_iterators_neg) == 0:
            # create batch iterator
            for data_loader_pos in data_loaders_pos:
                batch_iterators_pos.append(iter(data_loader_pos))
            for data_loader_neg in data_loaders_neg:
                batch_iterators_neg.append(iter(data_loader_neg))

        # if not batch_iterators_pos[curr_domain]:
        #     # create batch iterator
        #     batch_iterators_pos[curr_domain] = iter(data_loaders_pos[curr_domain])
        #
        # if not batch_iterators_neg[curr_domain]:
        #     # create batch iterator
        #     batch_iterators_neg[curr_domain] = iter(data_loaders_neg[curr_domain])

        # load train data; exhausted iterators are re-created per domain.
        if which_dataset[iteration % len(which_dataset)]:  # if positive
            try:
                images, bbox, action_label, score_label, vid_idx = next(batch_iterators_pos[curr_domain])
            except StopIteration:
                batch_iterators_pos[curr_domain] = iter(data_loaders_pos[curr_domain])
                images, bbox, action_label, score_label, vid_idx = next(batch_iterators_pos[curr_domain])
        else:
            try:
                images, bbox, action_label, score_label, vid_idx = next(batch_iterators_neg[curr_domain])
            except StopIteration:
                batch_iterators_neg[curr_domain] = iter(data_loaders_neg[curr_domain])
                images, bbox, action_label, score_label, vid_idx = next(batch_iterators_neg[curr_domain])

        # TODO: check if this requires grad is really false like in Variable
        # NOTE(review): torch.Tensor(existing_tensor) copies its argument;
        # batches coming from a DataLoader are already tensors, so .cuda()
        # (or .to(device)) alone would avoid the extra copy — verify before
        # changing, as it also affects requires_grad semantics.
        if args.cuda:
            images = torch.Tensor(images.cuda())
            bbox = torch.Tensor(bbox.cuda())
            action_label = torch.Tensor(action_label.cuda())
            score_label = torch.Tensor(score_label.float().cuda())
        else:
            images = torch.Tensor(images)
            bbox = torch.Tensor(bbox)
            action_label = torch.Tensor(action_label)
            score_label = torch.Tensor(score_label)

        t0 = time.time()

        # load ADNetDomainSpecific with video index
        if args.cuda:
            net.module.load_domain_specific(domain_specific_nets[curr_domain])
        else:
            net.load_domain_specific(domain_specific_nets[curr_domain])

        # forward
        action_out, score_out = net(images)

        # backprop
        optimizer.zero_grad()

        # Action loss only exists for positive samples; negatives contribute
        # a constant zero so `loss = action_l + score_l` stays uniform.
        if which_dataset[iteration % len(which_dataset)]:  # if positive
            action_l = action_criterion(action_out, torch.max(action_label, 1)[1])
        else:
            action_l = torch.Tensor([0])
        score_l = score_criterion(score_out, score_label.long())
        loss = action_l + score_l
        loss.backward()
        optimizer.step()

        action_loss += action_l.item()
        score_loss += score_l.item()

        # save the ADNetDomainSpecific back to their module
        if args.cuda:
            domain_specific_nets[curr_domain].load_weights_from_adnet(net.module)
        else:
            domain_specific_nets[curr_domain].load_weights_from_adnet(net)

        t1 = time.time()

        if iteration % 10 == 0:
            print('Timer: %.4f sec.' % (t1 - t0))
            print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data.item()), end=' ')

            if args.visualize and args.send_images_to_visualization:
                random_batch_index = np.random.randint(images.size(0))
                writer.add_image('image', images.data[random_batch_index].cpu().numpy(), random_batch_index)

        if args.visualize:
            writer.add_scalars('data/iter_loss', {'action_loss': action_l.item(), 'score_loss': score_l.item(), 'total': (action_l.item() + score_l.item())}, global_step=iteration)
            # hacky fencepost solution for 0th epoch plot
            if iteration == 0:
                writer.add_scalars('data/epoch_loss', {'action_loss': action_loss, 'score_loss': score_loss, 'total': (action_loss + score_loss)}, global_step=epoch)

        if iteration % 5000 == 0:
            print('Saving state, iter:', iteration)

            domain_specific_nets_state_dict = []
            for domain_specific_net in domain_specific_nets:
                domain_specific_nets_state_dict.append(domain_specific_net.state_dict())

            # NOTE(review): same as the epoch checkpoint above — the
            # collected state dicts are not the objects being saved.
            torch.save({
                'epoch': epoch,
                'adnet_state_dict': net.state_dict(),
                'adnet_domain_specific_state_dict': domain_specific_nets,
                'optimizer_state_dict': optimizer.state_dict(),
            }, os.path.join(args.save_folder, args.save_file) + repr(iteration) + '_epoch' + repr(epoch) + '.pth')

    # final save
    torch.save({
        'epoch': epoch,
        'adnet_state_dict': net.state_dict(),
        'adnet_domain_specific_state_dict': domain_specific_nets,
        'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(args.save_folder, args.save_file) + '.pth')

    return net, domain_specific_nets, train_videos
parser.add_argument('--multidomain', default=True, type=str2bool, help='Separating weight for each videos (default) or not') parser.add_argument('--save_result_images', default=True, type=str2bool, help='Whether to save the results or not. Save folder: images/') parser.add_argument('--display_images', default=True, type=str2bool, help='Whether to display images or not') args = parser.parse_args() # Supervised Learning part if args.run_supervised: opts['minibatch_size'] = 128 # train with supervised learning _, _, train_videos = adnet_train_sl(args, opts) args.resume = os.path.join(args.save_folder, args.save_file) + '.pth' # reinitialize the network with network from SL net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=True, multidomain=args.multidomain) args.start_epoch = 0 args.start_iter = 0 else: assert args.resume is not None, \ "Please put result of supervised learning or reinforcement learning with --resume (filename)" train_videos = get_train_videos(opts) opts['num_videos'] = len(train_videos['video_names']) if args.start_iter == 0: # means the weight came from the SL net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=True, multidomain=args.multidomain) else: # resume the adnet net, domain_specific_nets = adnet(opts, trained_file=args.resume, random_initialize_domain_specific=False, multidomain=args.multidomain)
def do():
    """Entry point: parse CLI arguments, optionally run the supervised
    learning stage (training or evaluation), then run reinforcement
    learning.

    Reads the module-level ``opts`` dict and mutates
    ``opts['minibatch_size']`` for each stage.  Side effects: trains or
    evaluates networks and writes checkpoints under ``args.save_folder``.
    """
    parser = argparse.ArgumentParser(description='ADNet training')
    # FIX: --adnet_mot and --test previously used type=str and --mot used
    # type=bool (with default=False).  argparse would then pass any supplied
    # value — including the literal string "False" — through str()/bool(),
    # yielding a truthy result, so the flags could never be disabled from
    # the command line.  str2bool (already used by every other boolean flag
    # in this parser) parses the string properly.
    parser.add_argument('--adnet_mot', default=False, type=str2bool, help='Whether to test or train.')
    parser.add_argument('--test', default=False, type=str2bool, help='Whether to test or train.')
    # parser.add_argument('--resume', default=None, type=str, help='Resume from checkpoint')
    parser.add_argument('--mot', default=False, type=str2bool, help='Perform MOT tracking')
    parser.add_argument('--resume', default='weights/ADNet_RL_FINAL.pth', type=str, help='Resume from checkpoint')
    parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
    parser.add_argument(
        '--start_iter', default=0, type=int,
        help='Begin counting iterations starting from this value (should be used with resume)')
    parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
    parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
    parser.add_argument('--visualize', default=True, type=str2bool, help='Use tensorboardx to for loss visualization')
    parser.add_argument(
        '--send_images_to_visualization', type=str2bool, default=False,
        help='Sample a random image from each 10th batch, send it to visdom after augmentations step')
    parser.add_argument('--save_folder', default='weights', help='Location to save checkpoint models')
    parser.add_argument('--save_file', default='ADNet_SL_MOT', type=str, help='save file part of file name for SL')
    parser.add_argument('--save_file_RL', default='ADNet_RL_', type=str, help='save file part of file name for RL')
    parser.add_argument('--start_epoch', default=0, type=int, help='Begin counting epochs starting from this value')
    parser.add_argument('--run_supervised', default=False, type=str2bool, help='Whether to run supervised learning or not')
    parser.add_argument(
        '--multidomain', default=True, type=str2bool,
        help='Separating weight for each videos (default) or not')
    parser.add_argument(
        '--save_result_images', default=False, type=str2bool,
        help='Whether to save the results or not. Save folder: images/')
    parser.add_argument('--display_images', default=True, type=str2bool, help='Whether to display images or not')

    args = parser.parse_args()

    # Supervised Learning part
    if args.run_supervised:
        opts['minibatch_size'] = 256
        # train with supervised learning
        if args.test:
            args.save_file += "test"
            _, _, train_videos = adnet_test_sl(args, opts, mot=args.mot)
        else:
            if args.adnet_mot:
                _, _, train_videos = adnet_train_sl_mot(args, opts, mot=args.mot)
            else:
                _, _, train_videos = adnet_train_sl(args, opts, mot=args.mot)

        args.resume = os.path.join(args.save_folder, args.save_file) + '.pth'

        # reinitialize the network with network from SL
        net, domain_specific_nets = adnet(
            opts,
            trained_file=args.resume,
            random_initialize_domain_specific=True,
            multidomain=args.multidomain)

        args.start_epoch = 0
        args.start_iter = 0
    else:
        assert args.resume is not None, \
            "Please put result of supervised learning or reinforcement learning with --resume (filename)"

        train_videos = get_train_videos(opts)
        opts['num_videos'] = len(train_videos['video_names'])

        # NOTE: this branch is deliberately disabled ('if False'); resuming
        # always takes the else path below.  Kept for reference.
        if False and args.start_iter == 0:
            # means the weight came from the SL
            net, domain_specific_nets = adnet_mot(
                opts,
                trained_file=args.resume,
                random_initialize_domain_specific=True,
                multidomain=args.multidomain)
        else:
            # resume the adnet
            if args.adnet_mot:
                net, domain_specific_nets = adnet_mot(
                    opts,
                    trained_file=args.resume,
                    random_initialize_domain_specific=False,
                    multidomain=args.multidomain)
            else:
                net, domain_specific_nets = adnet(
                    opts,
                    trained_file=args.resume,
                    random_initialize_domain_specific=False,
                    multidomain=args.multidomain)

    # (previously wrapped in a no-op 'if True:' block)
    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    # Reinforcement Learning part
    opts['minibatch_size'] = opts['train']['RL_steps']

    if args.adnet_mot:
        net = adnet_train_rl_mot(net, domain_specific_nets, train_videos, opts, args, 2)
    else:
        net = adnet_train_rl(net, domain_specific_nets, train_videos, opts, args)
def adnet_test_sl(args, opts, mot):
    """Evaluate an SL-trained ADNet on its own pos/neg patch datasets.

    For every video domain, loads that domain's specific head into the
    shared network, runs the positive and negative validation loaders,
    accumulates action accuracy / action loss / score loss, and prints the
    per-domain means.  When ``args.display_images`` is set, each batch is
    also shown interactively with OpenCV ('q' skips the rest of the loader,
    's' saves the current frame to disk).

    Args:
        args: parsed command-line namespace (cuda, resume, save_folder,
            save_file, visualize, multidomain, display_images).
        opts: option dict; 'num_videos' is overwritten here.
        mot: when truthy, uses the MOT dataset initializer.

    Returns:
        tuple: ``(net, domain_specific_nets, train_videos)``.
        NOTE(review): ``sys.exit(0)`` below makes this return unreachable —
        presumably intentional for a standalone evaluation run; confirm.
    """
    if torch.cuda.is_available():
        if args.cuda:
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        if not args.cuda:
            print(
                "WARNING: It looks like you have a CUDA device, but aren't " +
                "using CUDA.\nRun with --cuda for optimal training speed.")
            torch.set_default_tensor_type('torch.FloatTensor')
    else:
        torch.set_default_tensor_type('torch.FloatTensor')

    if not os.path.exists(args.save_folder):
        os.mkdir(args.save_folder)

    if args.visualize:
        writer = SummaryWriter(
            log_dir=os.path.join('tensorboardx_log', args.save_file))

    train_videos = get_train_videos(opts)
    opts['num_videos'] = len(train_videos['video_names'])

    net, domain_specific_nets = adnet(opts=opts, trained_file=args.resume, multidomain=args.multidomain)

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
        net = net.cuda()

    net.eval()

    action_criterion = nn.CrossEntropyLoss()
    score_criterion = nn.BCELoss()

    print('generating Supervised Learning dataset..')
    # dataset = SLDataset(train_videos, opts, transform=
    if mot:
        datasets_pos, datasets_neg = initialize_pos_neg_dataset_mot(
            train_videos, opts, transform=ADNet_Augmentation(opts))
    else:
        datasets_pos, datasets_neg = initialize_pos_neg_dataset(
            train_videos, opts, transform=ADNet_Augmentation(opts))

    number_domain = opts['num_videos']

    # One pos and one neg dataset is expected per video domain.
    assert number_domain == len(datasets_pos), "Num videos given in opts is incorrect! It should be {}".format(len(datasets_neg))

    batch_iterators_pos_val = []
    batch_iterators_neg_val = []

    # calculating number of data
    len_dataset_pos = 0
    len_dataset_neg = 0
    for dataset_pos in datasets_pos:
        len_dataset_pos += len(dataset_pos)
    for dataset_neg in datasets_neg:
        len_dataset_neg += len(dataset_neg)

    epoch_size_pos = len_dataset_pos // opts['minibatch_size']
    epoch_size_neg = len_dataset_neg // opts['minibatch_size']
    epoch_size = epoch_size_pos + epoch_size_neg  # 1 epoch, how many iterations

    print("1 epoch = " + str(epoch_size) + " iterations")

    max_iter = opts['numEpoch'] * epoch_size
    print("maximum iteration = " + str(max_iter))

    data_loaders_pos_val = []
    data_loaders_neg_val = []

    for dataset_pos in datasets_pos:
        data_loaders_pos_val.append(
            data.DataLoader(dataset_pos, opts['minibatch_size'], num_workers=2, shuffle=True, pin_memory=True))
    for dataset_neg in datasets_neg:
        data_loaders_neg_val.append(
            data.DataLoader(dataset_neg, opts['minibatch_size'], num_workers=2, shuffle=True, pin_memory=True))

    net.eval()

    for curr_domain in range(number_domain):
        accuracy = []
        action_loss_val = []
        score_loss_val = []

        # load ADNetDomainSpecific with video index
        if args.cuda:
            net.module.load_domain_specific(domain_specific_nets[curr_domain])
        else:
            net.load_domain_specific(domain_specific_nets[curr_domain])

        # i == 0 -> positive loader, i == 1 -> negative loader.
        for i, temp in enumerate([
                data_loaders_pos_val[curr_domain], data_loaders_neg_val[curr_domain]
        ]):
            dont_show = False
            for images, bbox, action_label, score_label, indices in tqdm(temp):
                images = images.to('cuda', non_blocking=True)
                action_label = action_label.to('cuda', non_blocking=True)
                score_label = score_label.float().to('cuda', non_blocking=True)

                # forward
                action_out, score_out = net(images)

                # Action accuracy/loss only measured on positive samples;
                # the score loss is accumulated for both loaders.
                if i == 0:  # if positive
                    action_l = action_criterion(action_out, torch.max(action_label, 1)[1])
                    action_loss_val.append(action_l.item())
                    accuracy.append(
                        int(
                            action_label.argmax(axis=1).eq(
                                action_out.argmax(axis=1)).sum()) / len(action_label))

                score_l = score_criterion(score_out, score_label.reshape(-1, 1))
                score_loss_val.append(score_l.item())

                if args.display_images and not dont_show:
                    if i == 0:
                        dataset = datasets_pos[curr_domain]
                        color = (0, 255, 0)
                        conf = 1
                    else:
                        dataset = datasets_neg[curr_domain]
                        color = (0, 0, 255)
                        conf = 0

                    # NOTE(review): this loop rebinds the batch variables
                    # `bbox` and `action_label` to per-sample values read
                    # back from the dataset — harmless here because the
                    # batch versions are not used again this iteration.
                    for j, index in enumerate(indices):
                        im = cv2.imread(dataset.train_db['img_path'][index])
                        bbox = dataset.train_db['bboxes'][index]
                        action_label = np.array(dataset.train_db['labels'][index])

                        cv2.rectangle(im, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), color, 2)

                        print("\n\nTarget actions: {}".format(action_label.argmax()))
                        print("Predicted actions: {}".format(action_out.data[j].argmax()))
                        print("Target conf: {}".format(conf))
                        print("Predicted conf: {}".format(score_out.data[j]))
                        # print("Score loss: {}".format(score_l.item()))
                        # print("Action loss: {}".format(action_l.item()))

                        cv2.imshow("Test", im)
                        key = cv2.waitKey(0) & 0xFF
                        # if the `q` key was pressed, break from the loop
                        if key == ord("q"):
                            dont_show = True
                            break
                        elif key == ord("s"):
                            # NOTE(review): indexes action_out/score_out
                            # with `i` (loader index 0/1), not `j` (sample
                            # index) — looks like it should be `j`; confirm.
                            cv2.imwrite(
                                "vid {} t:{} p:{} c:{}.png".format(
                                    curr_domain, action_label.argmax(),
                                    action_out.data[i].argmax(),
                                    score_out.data[i].item()), im)

        print("Vid. {}".format(curr_domain))
        print("\tAccuracy: {}".format(np.mean(accuracy)))
        print("\tScore loss: {}".format(np.mean(score_loss_val)))
        print("\tAction loss: {}".format(np.mean(action_loss_val)))

    sys.exit(0)
    return net, domain_specific_nets, train_videos
def process_adnet_test(videos_infos, dataset_start_id, v_start_id, v_end_id, train_videos, save_root,
                       spend_times_share, vid_preds, opts, args, lock):
    """Worker that runs ADNet tracking over a slice of the video list.

    Loads the tracker (plus optional Siamese matcher and a detectron2
    detector), runs ``adnet_test`` on videos ``v_start_id..v_end_id-1``,
    stores each video's predictions into the shared ``vid_preds`` and folds
    its timing dict into the shared ``spend_times_share[0]`` under ``lock``.
    Intended to run as one of several worker processes — presumably each
    worker gets a disjoint [v_start_id, v_end_id) range; TODO confirm with
    the caller.

    Args:
        videos_infos: per-video info dicts; 'img_files' is read here.
        dataset_start_id: global offset of this dataset's first video; used
            to index ``vid_preds``.
        v_start_id, v_end_id: half-open range of video indices to process.
        train_videos: dict with 'video_names', used for result folders.
        save_root: base folder for per-video result images.
        spend_times_share: shared one-element container holding the
            aggregated timing dict (read-modify-write under ``lock``).
        vid_preds: shared container for per-video predictions.
        opts, args: option dict and CLI namespace passed through to
            ``adnet_test``.
        lock: mutex guarding the shared containers.
    """
    # Placeholder when the Siamese branch is disabled; adnet_test receives
    # '' instead of a network in that case.
    siamesenet = ''
    if args.useSiamese:
        siamesenet = SiameseNetwork().cuda()
        resume = args.weight_siamese
        # resume = False
        if resume:
            siamesenet.load_weights(resume)

    # print('Loading {}...'.format(args.weight_file))
    net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=False)

    net.eval()

    if args.cuda:
        net = nn.DataParallel(net)
        cudnn.benchmark = True
    if args.cuda:
        net = net.cuda()

    if args.cuda:
        net.module.set_phase('test')
    else:
        net.set_phase('test')

    # Detectron2 detector used for proposals/re-detection during tracking.
    register_ILSVRC()
    cfg = get_cfg()
    cfg.merge_from_file("../../../configs/COCO-Detection/faster_rcnn_R_101_FPN_3x.yaml")
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5  # set threshold for this model
    # Find a model from detectron2's model zoo. You can either use the https://dl.fbaipublicfiles.... url, or use the following shorthand
    # cfg.MODEL.WEIGHTS ="../datasets/tem/train_output/model_0449999.pth"
    cfg.MODEL.WEIGHTS = args.weight_detector
    cfg.MODEL.ROI_HEADS.NUM_CLASSES = 30
    metalog = MetadataCatalog.get("ILSVRC_VID_val")
    predictor = DefaultPredictor(cfg)
    class_names = metalog.get("thing_classes", None)

    for vidx in range(v_start_id, v_end_id):
        # for vidx in range(20):
        vid_folder = videos_infos[vidx]

        if args.save_result_images_bool:
            # NOTE(review): mutates args.save_result_images per video —
            # fine if each worker owns its args copy; confirm args is not
            # shared across processes.
            args.save_result_images = os.path.join(save_root, train_videos['video_names'][vidx])
            if not os.path.exists(args.save_result_images):
                os.makedirs(args.save_result_images)

        vid_pred, spend_time = adnet_test(net, predictor, siamesenet, metalog, class_names, vidx,
                                          vid_folder['img_files'], opts, args)

        # Merge this video's timings into the shared accumulator atomically.
        try:
            lock.acquire()
            spend_times = spend_times_share[0].copy()
            spend_times['predict'] += spend_time['predict']
            spend_times['n_predict_frames'] += spend_time['n_predict_frames']
            spend_times['track'] += spend_time['track']
            spend_times['n_track_frames'] += spend_time['n_track_frames']
            spend_times['readframe'] += spend_time['readframe']
            spend_times['n_readframe'] += spend_time['n_readframe']
            spend_times['append'] += spend_time['append']
            spend_times['n_append'] += spend_time['n_append']
            spend_times['transform'] += spend_time['transform']
            spend_times['n_transform'] += spend_time['n_transform']
            spend_times['argmax_after_forward'] += spend_time['argmax_after_forward']
            spend_times['n_argmax_after_forward'] += spend_time['n_argmax_after_forward']
            spend_times['do_action'] += spend_time['do_action']
            spend_times['n_do_action'] += spend_time['n_do_action']
            spend_times_share[0] = spend_times
            vid_preds[vidx - dataset_start_id] = vid_pred
        except Exception as err:
            # NOTE(review): bare re-raise is a no-op except/raise pair; the
            # finally below is what actually guarantees the lock release.
            raise err
        finally:
            lock.release()
# dataset_root = os.path.join('../datasets/data', opts['test_db']) # vid_folders = [] # for filename in os.listdir(dataset_root): # if os.path.isdir(os.path.join(dataset_root,filename)): # vid_folders.append(filename) # vid_folders.sort(key=str.lower) # all_precisions = [] save_root = args.save_result_images # save_root_npy = args.save_result_npy opts['num_videos'] = 1 if not args.multi_cpu_eval: print('Loading {}...'.format(args.weight_file)) net, domain_nets = adnet(opts, trained_file=args.weight_file, random_initialize_domain_specific=False) net.eval() if args.cuda: net = nn.DataParallel(net) cudnn.benchmark = True if args.cuda: net = net.cuda() if args.cuda: net.module.set_phase('test') else: net.set_phase('test') if args.test1vid: vid_path = args.testVidPath vid_folder = vid_path.split('/')[-2] # vid_path = "../../../demo/examples/jiaotong2.avi"
def adnet_train_sl(args, opts, mot): if torch.cuda.is_available(): if args.cuda: torch.set_default_tensor_type('torch.cuda.FloatTensor') if not args.cuda: print( "WARNING: It looks like you have a CUDA device, but aren't " + "using CUDA.\nRun with --cuda for optimal training speed.") torch.set_default_tensor_type('torch.FloatTensor') else: torch.set_default_tensor_type('torch.FloatTensor') if not os.path.exists(args.save_folder): os.mkdir(args.save_folder) if args.visualize: writer = SummaryWriter( log_dir=os.path.join('tensorboardx_log', args.save_file)) train_videos = get_train_videos(opts) opts['num_videos'] = len(train_videos['video_names']) net, domain_specific_nets = adnet(opts=opts, trained_file=args.resume, multidomain=args.multidomain) if args.cuda: net = nn.DataParallel(net) cudnn.benchmark = True net = net.cuda() if args.cuda: optimizer = optim.Adam( [{ 'params': net.module.base_network.parameters(), 'lr': 1e-4 }, { 'params': net.module.fc4_5.parameters() }, { 'params': net.module.fc6.parameters() }, { 'params': net.module.fc7.parameters() }], # as action dynamic is zero, it doesn't matter lr=1e-3, weight_decay=opts['train']['weightDecay']) else: optimizer = optim.SGD([{ 'params': net.base_network.parameters(), 'lr': 1e-4 }, { 'params': net.fc4_5.parameters() }, { 'params': net.fc6.parameters() }, { 'params': net.fc7.parameters() }], lr=1e-3, momentum=opts['train']['momentum'], weight_decay=opts['train']['weightDecay']) if args.resume: # net.load_weights(args.resume) checkpoint = torch.load(args.resume) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) net.train() if not args.resume: print('Initializing weights...') if args.cuda: norm_std = 0.01 # fc 4 nn.init.normal_(net.module.fc4_5[0].weight.data, std=norm_std) net.module.fc4_5[0].bias.data.fill_(0.1) # fc 5 nn.init.normal_(net.module.fc4_5[3].weight.data, std=norm_std) net.module.fc4_5[3].bias.data.fill_(0.1) # fc 6 nn.init.normal_(net.module.fc6.weight.data, std=norm_std) 
net.module.fc6.bias.data.fill_(0) # fc 7 nn.init.normal_(net.module.fc7.weight.data, std=norm_std) net.module.fc7.bias.data.fill_(0) else: scal = torch.Tensor([0.01]) # fc 4 nn.init.normal_(net.fc4_5[0].weight.data) net.fc4_5[0].weight.data = net.fc4_5[ 0].weight.data * scal.expand_as(net.fc4_5[0].weight.data) net.fc4_5[0].bias.data.fill_(0.1) # fc 5 nn.init.normal_(net.fc4_5[3].weight.data) net.fc4_5[3].weight.data = net.fc4_5[ 3].weight.data * scal.expand_as(net.fc4_5[3].weight.data) net.fc4_5[3].bias.data.fill_(0.1) # fc 6 nn.init.normal_(net.fc6.weight.data) net.fc6.weight.data = net.fc6.weight.data * scal.expand_as( net.fc6.weight.data) net.fc6.bias.data.fill_(0) # fc 7 nn.init.normal_(net.fc7.weight.data) net.fc7.weight.data = net.fc7.weight.data * scal.expand_as( net.fc7.weight.data) net.fc7.bias.data.fill_(0) action_criterion = nn.CrossEntropyLoss() score_criterion = nn.BCELoss() print('generating Supervised Learning dataset..') # dataset = SLDataset(train_videos, opts, transform= if mot: datasets_pos, datasets_neg = initialize_pos_neg_dataset_mot( train_videos, opts, transform=ADNet_Augmentation(opts)) else: datasets_pos, datasets_neg = initialize_pos_neg_dataset( train_videos, opts, transform=ADNet_Augmentation(opts)) number_domain = opts['num_videos'] assert number_domain == len( datasets_pos ), "Num videos given in opts is incorrect! 
It should be {}".format( len(datasets_neg)) batch_iterators_pos_train = [] batch_iterators_neg_train = [] # calculating number of data len_dataset_pos = 0 len_dataset_neg = 0 for dataset_pos in datasets_pos: len_dataset_pos += len(dataset_pos) for dataset_neg in datasets_neg: len_dataset_neg += len(dataset_neg) epoch_size_pos = len_dataset_pos // opts['minibatch_size'] epoch_size_neg = len_dataset_neg // opts['minibatch_size'] epoch_size = epoch_size_pos + epoch_size_neg # 1 epoch, how many iterations print("1 epoch = " + str(epoch_size) + " iterations") max_iter = opts['numEpoch'] * epoch_size print("maximum iteration = " + str(max_iter)) data_loaders_pos_train = [] data_loaders_pos_val = [] data_loaders_neg_train = [] data_loaders_neg_val = [] for dataset_pos in datasets_pos: num_val = int(opts['val_percent'] * len(dataset_pos)) num_train = len(dataset_pos) - num_val train, valid = torch.utils.data.random_split(dataset_pos, [num_train, num_val]) data_loaders_pos_train.append( data.DataLoader(train, opts['minibatch_size'], num_workers=2, shuffle=True, pin_memory=True)) data_loaders_pos_val.append( data.DataLoader(valid, opts['minibatch_size'], num_workers=0, shuffle=True, pin_memory=False)) for dataset_neg in datasets_neg: num_val = int(opts['val_percent'] * len(dataset_neg)) num_train = len(dataset_neg) - num_val train, valid = torch.utils.data.random_split(dataset_neg, [num_train, num_val]) data_loaders_neg_train.append( data.DataLoader(train, opts['minibatch_size'], num_workers=1, shuffle=True, pin_memory=True)) data_loaders_neg_val.append( data.DataLoader(valid, opts['minibatch_size'], num_workers=0, shuffle=True, pin_memory=False)) epoch = args.start_epoch if epoch != 0 and args.start_iter == 0: start_iter = epoch * epoch_size else: start_iter = args.start_iter which_dataset = list(np.full(epoch_size_pos, fill_value=1)) which_dataset.extend(np.zeros(epoch_size_neg, dtype=int)) shuffle(which_dataset) which_dataset = torch.Tensor(which_dataset).cuda() 
which_domain = np.random.permutation(number_domain) action_loss_tr = 0 score_loss_tr = 0 # training loop time_arr = np.zeros(10) for iteration in tqdm(range(start_iter, max_iter)): t0 = time.time() if args.multidomain: curr_domain = which_domain[iteration % len(which_domain)] else: curr_domain = 0 # if new epoch (not including the very first iteration) if (iteration != start_iter) and (iteration % epoch_size == 0): epoch += 1 shuffle(which_dataset) np.random.shuffle(which_domain) print('Saving state, epoch: {}'.format(epoch)) domain_specific_nets_state_dict = [] for domain_specific_net in domain_specific_nets: domain_specific_nets_state_dict.append( domain_specific_net.state_dict()) torch.save( { 'epoch': epoch, 'adnet_state_dict': net.state_dict(), 'adnet_domain_specific_state_dict': domain_specific_nets, 'optimizer_state_dict': optimizer.state_dict(), }, os.path.join(args.save_folder, args.save_file) + 'epoch' + repr(epoch) + '.pth') # VAL for curr_domain_temp in range(number_domain): accuracy = [] action_loss_val = [] score_loss_val = [] # load ADNetDomainSpecific with video index if args.cuda: net.module.load_domain_specific( domain_specific_nets[curr_domain_temp]) else: net.load_domain_specific( domain_specific_nets[curr_domain_temp]) for i, temp in enumerate([ data_loaders_pos_val[curr_domain_temp], data_loaders_neg_val[curr_domain_temp] ]): for images, bbox, action_label, score_label, _ in temp: images = images.to('cuda', non_blocking=True) action_label = action_label.to('cuda', non_blocking=True) score_label = score_label.float().to('cuda', non_blocking=True) # forward action_out, score_out = net(images) if i == 0: # if positive action_l = action_criterion( action_out, torch.max(action_label, 1)[1]) accuracy.append( int( action_label.argmax(axis=1).eq( action_out.argmax(axis=1)).sum()) / len(action_label)) action_loss_val.append(action_l.item()) score_l = score_criterion(score_out, score_label.reshape(-1, 1)) score_loss_val.append(score_l.item()) 
print("Vid. {}".format(curr_domain)) print("\tAccuracy: {}".format(np.mean(accuracy))) print("\tScore loss: {}".format(np.mean(score_loss_val))) print("\tAction loss: {}".format(np.mean(action_loss_val))) if args.visualize: writer.add_scalars( 'data/val_video_{}'.format(curr_domain_temp), { 'action_loss_val': np.mean(action_loss_val), 'score_loss_val': np.mean(score_loss_val), 'total_val': np.mean(score_loss_val) + np.mean(action_loss_val), 'accuracy': np.mean(accuracy) }, global_step=epoch) if args.visualize: writer.add_scalars('data/epoch_loss', { 'action_loss_tr': action_loss_tr / epoch_size_pos, 'score_loss_tr': score_loss_tr / epoch_size, 'total_tr': action_loss_tr / epoch_size_pos + score_loss_tr / epoch_size }, global_step=epoch) # reset epoch loss counters action_loss_tr = 0 score_loss_tr = 0 # if new epoch (including the first iteration), initialize the batch iterator # or just resuming where batch_iterator_pos and neg haven't been initialized if len(batch_iterators_pos_train) == 0 or len( batch_iterators_neg_train) == 0: # create batch iterator for data_loader_pos in data_loaders_pos_train: batch_iterators_pos_train.append(iter(data_loader_pos)) for data_loader_neg in data_loaders_neg_train: batch_iterators_neg_train.append(iter(data_loader_neg)) # if not batch_iterators_pos_train[curr_domain]: # # create batch iterator # batch_iterators_pos_train[curr_domain] = iter(data_loaders_pos_train[curr_domain]) # # if not batch_iterators_neg_train[curr_domain]: # # create batch iterator # batch_iterators_neg_train[curr_domain] = iter(data_loaders_neg_train[curr_domain]) # load train data if which_dataset[iteration % len(which_dataset)]: # if positive try: images, bbox, action_label, score_label, vid_idx = next( batch_iterators_pos_train[curr_domain]) except StopIteration: batch_iterators_pos_train[curr_domain] = iter( data_loaders_pos_train[curr_domain]) images, bbox, action_label, score_label, vid_idx = next( batch_iterators_pos_train[curr_domain]) else: try: 
images, bbox, action_label, score_label, vid_idx = next( batch_iterators_neg_train[curr_domain]) except StopIteration: batch_iterators_neg_train[curr_domain] = iter( data_loaders_neg_train[curr_domain]) images, bbox, action_label, score_label, vid_idx = next( batch_iterators_neg_train[curr_domain]) # TODO: check if this requires grad is really false like in Variable if args.cuda: images = images.to('cuda', non_blocking=True) # bbox = torch.Tensor(bbox.cuda()) action_label = action_label.to('cuda', non_blocking=True) score_label = score_label.float().to('cuda', non_blocking=True) else: images = torch.Tensor(images) bbox = torch.Tensor(bbox) action_label = torch.Tensor(action_label) score_label = torch.Tensor(score_label) # TRAIN net.train() action_out, score_out = net(images) # load ADNetDomainSpecific with video index if args.cuda: net.module.load_domain_specific(domain_specific_nets[curr_domain]) else: net.load_domain_specific(domain_specific_nets[curr_domain]) # backprop optimizer.zero_grad() score_l = score_criterion(score_out, score_label.reshape(-1, 1)) if which_dataset[iteration % len(which_dataset)]: # if positive action_l = action_criterion(action_out, torch.max(action_label, 1)[1]) accuracy = int( action_label.argmax(axis=1).eq( action_out.argmax(axis=1)).sum()) / len(action_label) loss = action_l + score_l else: action_l = -1 accuracy = -1 loss = score_l loss.backward() optimizer.step() if action_l != -1: action_loss_tr += action_l.item() score_loss_tr += score_l.item() # save the ADNetDomainSpecific back to their module if args.cuda: domain_specific_nets[curr_domain].load_weights_from_adnet( net.module) else: domain_specific_nets[curr_domain].load_weights_from_adnet(net) if args.visualize: if action_l != -1: writer.add_scalars( 'data/iter_loss', { 'action_loss_tr': action_l.item(), 'score_loss_tr': score_l.item(), 'total_tr': (action_l.item() + score_l.item()) }, global_step=iteration) else: writer.add_scalars('data/iter_loss', { 'score_loss_tr': 
score_l.item(), 'total_tr': score_l.item() }, global_step=iteration) if accuracy >= 0: writer.add_scalars('data/iter_acc', {'accuracy_tr': accuracy}, global_step=iteration) t1 = time.time() time_arr[iteration % 10] = t1 - t0 if iteration % 10 == 0: # print('Avg. 10 iter time: %.4f sec.' % time_arr.sum()) # print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.data.item()), end=' ') if args.visualize and args.send_images_to_visualization: random_batch_index = np.random.randint(images.size(0)) writer.add_image('image', images.data[random_batch_index].cpu().numpy(), random_batch_index) if args.visualize: writer.add_scalars('data/time', {'time_10_it': time_arr.sum()}, global_step=iteration) if iteration % 5000 == 0: print('Saving state, iter:', iteration) domain_specific_nets_state_dict = [] for domain_specific_net in domain_specific_nets: domain_specific_nets_state_dict.append( domain_specific_net.state_dict()) torch.save( { 'epoch': epoch, 'adnet_state_dict': net.state_dict(), 'adnet_domain_specific_state_dict': domain_specific_nets, 'optimizer_state_dict': optimizer.state_dict(), }, os.path.join(args.save_folder, args.save_file) + repr(iteration) + '_epoch' + repr(epoch) + '.pth') # final save torch.save( { 'epoch': epoch, 'adnet_state_dict': net.state_dict(), 'adnet_domain_specific_state_dict': domain_specific_nets, 'optimizer_state_dict': optimizer.state_dict(), }, os.path.join(args.save_folder, args.save_file) + '.pth') return net, domain_specific_nets, train_videos
def main():
    """Train the ADNet tracker's policy with PPO reinforcement learning.

    Builds a single (non-vectorized) gym environment over ILSVRC video
    clips, wraps the supervised-pretrained ADNet backbone in an
    actor-critic ``Policy``, and runs epoch/clip/frame rollout loops,
    periodically checkpointing ``actor_critic.base`` to ``args.save_dir``.
    """
    args = get_args()

    # Seed CPU and all CUDA devices for reproducibility.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if args.cuda and torch.cuda.is_available() and args.cuda_deterministic:
        # Trade cuDNN autotuning for deterministic kernel selection.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # log_dir = os.path.expanduser(args.log_dir)
    log_dir = args.log_dir
    eval_log_dir = log_dir + "_eval"
    utils.cleanup_log_dir(log_dir)
    utils.cleanup_log_dir(eval_log_dir)

    # os.path.join with a single argument just passes args.save_dir through.
    save_path = os.path.join(args.save_dir)
    try:
        os.makedirs(save_path)
    except OSError:
        # Directory already exists (or cannot be created) -- best effort.
        pass

    torch.set_num_threads(1)
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # A single raw gym environment is used instead of the vectorized
    # make_vec_envs; .unwrapped strips gym's wrappers.
    # envs = make_vec_envs(args.env_name, args.seed, args.num_processes,
    #                      args.gamma, args.log_dir, device, False)
    env = gym.make(args.env_name).unwrapped

    # Secondary, self-contained argument set for the ILSVRC evaluation-list
    # loader. parse_args receives an explicit argv list, so args2 always
    # carries these literal values regardless of the real command line.
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval_imgs', default=0, type=int,
                        help='the num of imgs that picked from val.txt, 0 represent all imgs')
    parser.add_argument('--gt_skip', default=1, type=int,
                        help='frame sampling frequency')
    parser.add_argument('--dataset_year', default=2222, type=int,
                        help='dataset version, like ILSVRC2015, ILSVRC2017, 2222 means train.txt')
    args2 = parser.parse_args(['--eval_imgs', '0', '--gt_skip', '1', '--dataset_year', '2222'])
    videos_infos, _ = get_ILSVRC_eval_infos(args2)

    # Mean tensor fed to the patch-cropping transform.
    mean = np.array(opts['means'], dtype=np.float32)
    mean = torch.from_numpy(mean).cuda()
    transform = ADNet_Augmentation2(opts, mean)

    # for en in envs:
    #     en.init_data(videos_infos, opts, transform, do_action, overlap_ratio)
    env.init_data(videos_infos, opts, transform, do_action, overlap_ratio)

    # ADNet backbone, optionally resumed from a checkpoint; multidomain off.
    net, _ = adnet(opts, trained_file=args.resume,
                   random_initialize_domain_specific=True, multidomain=False)
    # net = net.cuda()

    actor_critic = Policy(
        env.observation_space.shape,
        env.action_space,
        base=net,
        base_kwargs={'recurrent': args.recurrent_policy})
    actor_critic.to(device)

    # Only the PPO branch of the reference A2C/PPO/ACKTR selection is kept;
    # the A2C_ACKTR and GAIL discriminator paths were removed as comments.
    agent = algo.PPO(
        actor_critic,
        args.clip_param,
        args.ppo_epoch,
        args.num_mini_batch,
        args.value_loss_coef,
        args.entropy_coef,
        lr=args.lr,
        eps=args.eps,
        max_grad_norm=args.max_grad_norm)

    rollouts = RolloutStorage(args.num_steps,
                              opts['inputSize_transpose'],
                              env.action_space, )

    for epoch in range(0, args.num_epoch):
        env.reset_env()
        rollouts.reset_storage()
        obs = env.reset()
        rollouts.obs[0].copy_(obs)
        rollouts.to(device)

        j = -1          # clip (video-snippet) counter within this epoch
        va = 0          # running sum of critic values since the last report
        n_va = 0        # number of values accumulated in `va`
        va_epoch = 0    # critic-value sum flushed per >=100-sample report
        n_va_epoch = 0  # matching count for `va_epoch`
        while True:
            j += 1  # current clip number
            actor_critic.base.reset_action_dynamic()

            if args.use_linear_lr_decay:
                # decrease learning rate linearly
                # utils.update_linear_schedule(
                #     agent.optimizer, j, num_updates,
                #     agent.optimizer.lr if args.algo == "acktr" else args.lr)
                utils.update_linear_schedule(
                    agent.optimizer, j, len(videos_infos),
                    agent.optimizer.lr if args.algo == "acktr" else args.lr)

            box_history_clip = []  # rounded boxes already visited in this clip
            t = 0                  # actions taken on the current frame
            # NOTE(review): `step` is never incremented, so actor_critic.act
            # always reads rollout slot 0 -- confirm this is intended (slot 0
            # is refreshed via get_current_patch() on every stop/done below).
            step = 0
            while True:  # one clip
                # Sample actions
                with torch.no_grad():
                    value, action, action_log_prob, recurrent_hidden_states = actor_critic.act(
                        rollouts.obs[step],
                        rollouts.recurrent_hidden_states[step],
                        rollouts.masks[step])
                va += value
                n_va += 1

                # Observe reward and next obs
                obs, new_state, reward, done, infos = env.step(action)
                reward = torch.Tensor([reward])

                # Single environment, no termination masking: both masks are
                # fixed to 1.0 (the vectorized done/bad_transition logic from
                # the reference implementation is not used here).
                masks = torch.FloatTensor([1.0])
                bad_masks = torch.FloatTensor([1.0])
                rollouts.insert(obs, recurrent_hidden_states, action,
                                action_log_prob, value, reward, masks,
                                bad_masks)

                # Force a stop if the new box (rounded) revisits a box already
                # seen in this clip, which would indicate oscillation.
                if ((action != opts['stop_action']) and any(
                        (np.array(new_state).round() == x).all()
                        for x in np.array(box_history_clip).round())):
                    action = opts['stop_action']
                    reward, done, finish_epoch = env.go_to_next_frame()
                    infos['finish_epoch'] = finish_epoch

                # Also force a stop once too many actions were spent on a frame.
                if t > opts['num_action_step_max']:
                    # todo: in this situation, reward/feedback should be punished.
                    action = opts['stop_action']
                    reward, done, finish_epoch = env.go_to_next_frame()
                    infos['finish_epoch'] = finish_epoch

                box_history_clip.append(list(new_state))
                t += 1

                if action == opts['stop_action']:  # finish one frame
                    t = 0
                    box_history_clip = []
                    rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())
                if done:  # if finish the clip
                    # rollouts.obs[rollouts.get_step()].copy_(obs)
                    rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())
                    break

            # Bootstrap value for the last stored observation.
            with torch.no_grad():
                next_value = actor_critic.get_value(
                    rollouts.obs[rollouts.get_step()],
                    rollouts.recurrent_hidden_states[rollouts.get_step()],
                    rollouts.masks[rollouts.get_step()]).detach()

            rollouts.compute_returns(next_value, args.use_gae, args.gamma,
                                     args.gae_lambda,
                                     args.use_proper_time_limits)

            # dist_entropy is computed by PPO but not reported here.
            value_loss, action_loss, dist_entropy = agent.update(rollouts)
            rollouts.obs[rollouts.get_step()].copy_(env.get_current_patch())
            rollouts.after_update()

            # Report the running-average critic value every >=100 samples.
            if n_va >= 100:
                ave_va = va / n_va
                print("current clip: %d, n_va: %d, cur v: %.2f, cur ave v: %.2f" % (j, n_va, value, ave_va))
                va_epoch += va
                n_va_epoch += n_va
                va = 0
                n_va = 0

            if infos['finish_epoch']:
                # NOTE(review): n_va_epoch is only incremented inside the
                # `n_va >= 100` report above; if the epoch finishes before any
                # report, this divides by zero -- confirm it cannot happen.
                ave_va_epoch = va_epoch / n_va_epoch
                print("epoch: %d, ave value of v: %.2f" % (epoch, ave_va_epoch))
                va_epoch = 0
                n_va_epoch = 0
                break

            # save for every interval-th clip (when a save dir is configured)
            if (j % args.save_interval == 0) and args.save_dir != "":
                torch.save({
                    'epoch': epoch,
                    'adnet_state_dict': actor_critic.base.state_dict(),
                    # 'adnet_domain_specific_state_dict': domain_specific_nets,
                    # 'optimizer_state_dict': optimizer.state_dict(),
                }, os.path.join(save_path, 'ADNet_RL_epoch' + repr(epoch) + "_" + repr(j) + '.pth'))

    # Final checkpoint after all epochs complete.
    torch.save({
        'epoch': epoch,
        'adnet_state_dict': actor_critic.base.state_dict(),
        # 'adnet_domain_specific_state_dict': domain_specific_nets,
        # 'optimizer_state_dict': optimizer.state_dict(),
    }, os.path.join(save_path, 'ADNet_RL_final.pth'))