def main_merge():
    """Train the Baseline model on the merged train+val split of the EEV data.

    Reads hyperparameters from the module-level ``args`` namespace, trains for
    ``args.epochs`` epochs, and saves a checkpoint after every epoch.  No
    validation is run on the merged split.
    """
    global args, best_corr
    # Tag this run's log/checkpoint folder with the model name and a timestamp.
    args.store_name = '{}_merged'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M')
    args.start_epoch = 0
    check_rootfolders(args)

    model = Baseline(args.img_feat_size, args.au_feat_size)
    model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    if args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        # NOTE(review): best_corr is restored here but never read below,
        # since this merged-data run performs no validation.
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)
    # initialize datasets: train and val CSVs/vidmaps are combined ('merge' mode).
    train_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=[args.train_csv, args.val_csv],
        vidmap_path=[args.train_vidmap, args.val_vidmap],
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='merge'),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=False, drop_last=True)
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    # Persist the full argument namespace next to the logs for reproducibility.
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, log_training, tb_writer)
        # No validation here, so best_corr is recorded as 0.0 and the
        # checkpoint is never flagged as "best" (is_best=False).
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_corr': 0.0,
            }, False)
        if args.use_multistep:
            scheduler.step()
def main():
    """Evaluate a trained CULane Baseline checkpoint on the test split.

    Loads the checkpoint named by ``args['checkpoint']``, runs inference over
    every image listed in ``test_gt.txt``, optionally dumps per-image lane
    predictions, and writes the aggregate metrics to a timestamped log file.
    """
    net = Baseline(num_classes=culane.num_classes, deep_base=args['deep_base']).cuda()
    print('load checkpoint \'%s.pth\' for evaluation' % args['checkpoint'])
    pretrained_dict = torch.load(os.path.join(ckpt_path, exp_name, args['checkpoint'] + '_checkpoint.pth'))
    # The checkpoint was saved from a DataParallel wrapper: strip the
    # leading 'module.' (7 chars) from every key before loading.
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()}
    net.load_state_dict(pretrained_dict)
    net.eval()

    save_dir = os.path.join(ckpt_path, exp_name, 'vis_%s_test' % args['checkpoint'])
    check_mkdir(save_dir)
    log_path = os.path.join(save_dir, str(datetime.datetime.now()) + '.log')

    # fix: close the list file instead of leaking the handle from a comprehension
    with open(os.path.join(culane.root, culane.list, 'test_gt.txt'), 'r') as list_file:
        data_list = [l.strip('\n') for l in list_file]

    loss_record = AverageMeter()
    gt_all, prediction_all = [], []

    for idx in range(len(data_list)):
        print('evaluating %d / %d' % (idx + 1, len(data_list)))
        # Each list line is "<image path> <label path> ..." (paths start with '/').
        fields = data_list[idx].split(' ')
        img_path, gt_path = fields[0], fields[1]
        img = Image.open(culane.root + img_path).convert('RGB')
        gt = Image.open(culane.root + gt_path)
        img, gt = val_joint_transform(img, gt)
        with torch.no_grad():
            img_var = Variable(img_transform(img).unsqueeze(0)).cuda()
            gt_var = Variable(mask_transform(gt).unsqueeze(0)).cuda()
            prediction = net(img_var)[0]
            loss = criterion(prediction, gt_var)
            loss_record.update(loss.data, 1)
            # Per-class probability map used by the lane extractor below.
            scoremap = F.softmax(prediction, dim=1).data.squeeze().cpu().numpy()
        prediction = prediction.data.max(1)[1].squeeze().cpu().numpy().astype(np.uint8)
        prediction_all.append(prediction)
        gt_all.append(np.array(gt))
        if args['save_results']:
            # Drop the file name (last 10 chars, e.g. '/00000.jpg') to get the directory.
            check_mkdir(save_dir + img_path[:-10])
            # fix: close each per-image lane file after prob2lines writes it
            with open(os.path.join(save_dir, img_path[1:-4] + '.lines.txt'), 'w') as out_file:
                prob2lines(scoremap, out_file)

    acc, acc_cls, mean_iu, fwavacc = evaluation(prediction_all, gt_all, culane.num_classes)
    log = 'val results: loss %.5f acc %.5f acc_cls %.5f mean_iu %.5f fwavacc %.5f' % \
          (loss_record.avg, acc, acc_cls, mean_iu, fwavacc)
    print(log)
    # fix: use a context manager instead of an unclosed open().write()
    with open(log_path, 'w') as log_file:
        log_file.write(log + '\n')
def test():
    """Watch an ensemble of the top saved policies drive one episode.

    The highest-scoring checkpoints are loaded from ./policy_grad, each casts
    a greedy-action vote per step, and the majority action is executed.
    """
    # Prepare the environment.
    env = create_env()
    h, w, c = env.observation_space.shape

    # Rank checkpoints by the score embedded in the filename
    # ("<prefix>_<score>.pth") and keep the highest-scoring ones.
    # NOTE(review): the original comment said "5 best models" but only 3
    # are selected — behavior kept as-is; confirm the intended count.
    device = torch.device("cpu")
    model_dir = "./policy_grad"
    scores = {}
    for name in os.listdir(model_dir):
        if not name.endswith('.pth'):
            continue
        scores[name] = float(name.split("_")[-1][:-4])
    best_names = heapq.nlargest(3, scores, key=scores.get)

    models = []
    for name in best_names:
        net = Baseline(h, w).to(device)
        net.load_state_dict(torch.load(os.path.join(model_dir, name), map_location='cpu'))
        net.eval()
        models.append(net)

    # Watch the race car perform.
    obs = env.reset().transpose((2, 0, 1))  # HWC -> CHW for the network
    state = torch.tensor([obs], dtype=torch.float, device=device)
    total_reward = 0
    for t in count():
        # Each model votes with its greedy action; execute the majority choice.
        votes = [net(state)[0].argmax().item() for net in models]
        action_idx = Counter(votes).most_common(1)[0][0]
        next_obs, reward, done, _ = env.step(index_to_action(action_idx))
        env.render()

        # Advance the observation and running score.
        state = torch.tensor([next_obs.transpose((2, 0, 1))],
                             dtype=torch.float, device=device)
        total_reward += reward
        if done:
            break
    print("Total reward: {}".format(total_reward))
def main():
    """Train (or resume training of) the CULane Baseline segmentation model."""
    net = Baseline(num_classes=culane.num_classes, deep_base=args['deep_base']).cuda().train()
    net = DataParallelWithCallback(net)

    # Partition parameters in a single pass (the original scanned
    # named_parameters() twice): biases get twice the base learning rate,
    # everything else gets the base learning rate.
    bias_params, weight_params = [], []
    for name, param in net.named_parameters():
        (bias_params if name[-4:] == 'bias' else weight_params).append(param)
    optimizer = optim.SGD(
        [{'params': bias_params, 'lr': 2 * args['base_lr']},
         {'params': weight_params, 'lr': args['base_lr']}],
        momentum=args['momentum'])

    if len(args['checkpoint']) > 0:
        print('training resumes from \'%s\'' % args['checkpoint'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint.pth')))
        optimizer.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name,
                             args['checkpoint'] + '_checkpoint_optim.pth')))
        # Restore the intended learning rates: the loaded optimizer state may
        # carry stale values from the end of the previous run.
        optimizer.param_groups[0]['lr'] = 2 * args['base_lr']
        optimizer.param_groups[1]['lr'] = args['base_lr']

    check_mkdir(os.path.join(ckpt_path, exp_name))
    # fix: close the log file instead of leaking the handle via open().write()
    with open(log_path, 'w') as log_file:
        log_file.write(str(args) + '\n\n')
    train(net, optimizer)
def main_test():
    """Run inference on the EEV test split and write `test_output.csv`.

    Loads the checkpoint given by ``args.resume`` (aborts if absent), scores
    every test video at 6 rows per second, pads a fixed list of videos whose
    features are missing with zero rows, and writes one CSV row per frame.
    """
    print('Running test...')
    torch.multiprocessing.set_sharing_strategy('file_system')
    model = Baseline()
    if args.use_swa:
        # Wrap before DataParallel so the SWA checkpoint's key layout matches.
        model = torch.optim.swa_utils.AveragedModel(model)
    model = torch.nn.DataParallel(model).cuda()
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        print('Loaded ckpt at epoch:', args.start_epoch)
    else:
        print('No model given. Abort!')
        exit(1)
    test_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=None,
            vidmap_path=args.test_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='test',
            test_freq=args.test_freq
        ),
        # batch_size=None: the dataset yields one whole video per item.
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )
    model.eval()
    batch_time = AverageMeter()
    t_start = time.time()
    outputs = []
    with torch.no_grad():
        for i, (img_feat, au_feat, frame_count, vid) in enumerate(test_loader):
            img_feat = torch.stack(img_feat).cuda()
            au_feat = torch.stack(au_feat).cuda()
            assert len(au_feat.size()) == 3, 'bad auf %s' % (vid)
            output = model(img_feat, au_feat)  # [Clip S 15]
            # rearrange and remove extra padding in the end
            output = rearrange(output, 'Clip S C -> (Clip S) C')
            output = torch.cat([output, output[-1:]])  # repeat the last frame to avoid missing
            if args.train_freq < args.test_freq:
                # Upsample model output from train frequency to the 6 Hz
                # submission rate.
                # NOTE(review): the target frequency 6 is hard-coded here —
                # presumably it should equal args.test_freq; confirm.
                # print('interpolating:', output.size()[0], frame_count)
                output = interpolate_output(output, args.train_freq, 6)
                # print('Interpolated:', output.size()[0], frame_count)
            # truncate extra frames
            assert output.size(0) >= frame_count, '{}/{}'.format(output.size(0), frame_count)
            output = output[:frame_count]
            outputs.append((vid, frame_count, output.cpu().detach().numpy()))
            # update statistics
            batch_time.update(time.time() - t_start)
            t_start = time.time()
            if i % args.print_freq == 0:
                output = ('Test: [{0}/{1}]\t'
                          'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                              i, len(test_loader), batch_time=batch_time))
                print(output)
    # Six evenly spaced timestamps (microseconds) per one-second step.
    time_stamps = [0, 166666, 333333, 500000, 666666, 833333]
    time_step = 1000000  # time starts at 0
    header = 'Video ID,Timestamp (milliseconds),amusement,anger,awe,concentration,confusion,contempt,contentment,disappointment,doubt,elation,interest,pain,sadness,surprise,triumph\n'
    final_res = {}
    # Convert each video's score matrix into CSV rows keyed by video id.
    for vid, frame_count, out in outputs:  # videos
        video_time = frame_count // 6 + 1
        # print('video', vid, video_time)
        entry_count = 0
        for t in range(video_time):  # seconds
            for i in range(6):  # frames
                timestamp = time_step * t + time_stamps[i]
                fcc = t * 6 + i
                if fcc >= frame_count:
                    continue
                # print('Frame count', frame_count)
                frame_output = out[fcc]
                frame_output = [str(x) for x in frame_output]
                temp = '{vid},{timestamp},'.format(vid=vid, timestamp=timestamp) + ','.join(frame_output) + '\n'
                # file.write(temp)
                if vid in final_res:
                    final_res[vid].append(temp)
                else:
                    final_res[vid] = [temp]
                entry_count += 1
        # Every frame of the video must have produced exactly one CSV row.
        assert entry_count == frame_count
    # Videos with no extracted features: emit all-zero rows so the submission
    # still covers them.  List is hard-coded — fixed for now.
    missing = [('WKXrnB7alT8', 2919), ('o0ooW14pIa4', 3733), ('GufMoL_MuNE', 2038),
               ('Uee0Tv1rTz8', 1316), ('ScvvOWtb04Q', 152), ('R9kJlLungmo', 3609),
               ('QMW3GuohzzE', 822), ('fjJYTW2n6rk', 4108), ('rbTIMt0VcLw', 1084),
               ('L9cdaj74kLo', 3678), ('l-ka23gU4NA', 1759)]
    for vid, length in missing:
        video_time = length // 6 + 1
        # print('video', vid, video_time)
        for t in range(video_time):  # seconds
            for i in range(6):  # frames
                timestamp = time_step * t + time_stamps[i]
                fcc = t * 6 + i
                if fcc >= length:
                    continue
                frame_output = ',0' * 15
                temp = '{vid},{timestamp}'.format(vid=vid, timestamp=timestamp) + frame_output + '\n'
                # file.write(temp)
                if vid in final_res:
                    final_res[vid].append(temp)
                else:
                    final_res[vid] = [temp]
    print('Write test outputs...')
    # Write rows in the order the test vidmap lists the videos.
    with open('test_output.csv', 'w') as file:
        file.write(header)
        temp_vidmap = [x.strip().split(' ') for x in open(args.test_vidmap)]
        temp_vidmap = [x[0] for x in temp_vidmap]
        for vid in tqdm(temp_vidmap):
            for entry in final_res[vid]:
                file.write(entry)
def main_train(config, checkpoint_dir=None):
    """Ray Tune trainable: train and validate one hyperparameter configuration.

    Args:
        config: Tune-sampled hyperparameters; uses keys 'optimizer', 'lr',
            and 'batch_size'.  Fixed settings come from the global ``args``.
        checkpoint_dir: if set by Tune, resume model/optimizer state from the
            "checkpoint" file inside it.

    Reports loss/accuracy/best_corr to Tune after every epoch.
    """
    global args, best_corr
    best_corr = 0.0
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0
    # check_rootfolders(args)
    if args.model == 'Baseline':
        model = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(2048 + 128), num_classes=15, kernel_size=11)
    model = torch.nn.DataParallel(model).cuda()
    # NOTE(review): 'adam' uses args.learning_rate while 'adamw' uses
    # config['lr']; an unrecognized config['optimizer'] (without use_sam)
    # leaves `optimizer` unbound — confirm the config space guards this.
    if config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])
    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=config['lr'])
    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)
    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=config['lr'])
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    # if args.resume and os.path.isfile(args.resume):
    #     print('Load checkpoint:', args.resume)
    #     ckpt = torch.load(args.resume)
    #     args.start_epoch = ckpt['epoch']
    #     best_corr = ckpt['best_corr']
    #     model.load_state_dict(ckpt['state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer'])
    #     print('Loaded ckpt at epoch:', args.start_epoch)
    if checkpoint_dir:
        # Tune checkpoints are a (model_state, optimizer_state) tuple (see
        # torch.save below).
        model_state, optimizer_state = torch.load(
            os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)
    # initialize datasets
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter
        ),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=args.workers, pin_memory=False, drop_last=True
    )
    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val'
        ),
        # batch_size=None: validation dataset yields one whole video per item.
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False
    )
    accuracy = correlation
    # with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
    #     f.write(str(args))
    # tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        # train
        train(train_loader, model, optimizer, epoch, None, None)
        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()
        # validate
        if args.use_swa and epoch >= args.swa_start:
            # validate use swa model
            corr, loss = validate(val_loader, swa_model, accuracy, epoch, None, None)
        else:
            corr, loss = validate(val_loader, model, accuracy, epoch, None, None)
        is_best = corr > best_corr
        best_corr = max(corr, best_corr)
        # tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
        # output_best = 'Best corr: %.4f\n' % (best_corr)
        # print(output_best)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'state_dict': model.state_dict(),
        #     'optimizer': optimizer.state_dict(),
        #     'best_corr': best_corr,
        # }, is_best)
        # Hand the per-epoch checkpoint to Tune; best epochs get a
        # distinct file name.
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            if is_best:
                path = os.path.join(checkpoint_dir, "checkpoint_best")
            torch.save((model.state_dict(), optimizer.state_dict()), path)
        tune.report(loss=loss, accuracy=corr, best_corr=best_corr)
def main_train():
    """Standard training entry point for the EEV emotion models.

    Builds the model selected by ``args.model``, optionally resumes from
    ``args.resume``, trains with the configured optimizer/scheduler/SWA
    options, and periodically validates and checkpoints.  With
    ``args.val_only`` set, runs a single validation pass and returns.
    """
    global args, best_corr
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime(
        '_%m-%d_%H-%M-%S')
    args.start_epoch = 0
    if not args.val_only:
        check_rootfolders(args)
    # Model selection; cls_indices restricts Baseline to a subset of classes.
    if args.model == 'Baseline':
        if args.cls_indices:
            model = Baseline(args.img_feat_size, args.au_feat_size,
                             num_classes=len(args.cls_indices))
        else:
            print('Feature size:', args.img_feat_size, args.au_feat_size)
            model = Baseline(args.img_feat_size, args.au_feat_size)
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(128), num_classes=15,
                      kernel_size=11)
    elif args.model == 'BaseAu':
        model = Baseline_Au(args.au_feat_size)
    elif args.model == 'BaseImg':
        model = Baseline_Img(args.img_feat_size)
    elif args.model == 'EmoBase':
        model = EmoBase()
    model = torch.nn.DataParallel(model).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)
    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)
    # custom lr scheduler (the three options are mutually exclusive)
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, args.cos_t_max)
    elif args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)
    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer,
                                                    swa_lr=args.learning_rate)
    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)
    # NOTE(review): when not resuming, best_corr relies on a module-level
    # initialization (it is declared global above) — confirm it is set there.
    # initialize datasets
    train_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.train_csv,
        vidmap_path=args.train_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='train',
        lpfilter=args.lp_filter,
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=False, drop_last=True)
    val_loader = torch.utils.data.DataLoader(dataset=EEV_Dataset(
        csv_path=args.val_csv,
        vidmap_path=args.val_vidmap,
        image_feat_path=args.image_features,
        audio_feat_path=args.audio_features,
        mode='val',
        train_freq=args.train_freq,
        val_freq=args.val_freq,
        cls_indices=args.cls_indices,
        repeat_sample=args.repeat_sample),
        # batch_size=None: validation dataset yields one whole video per item.
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False)
    accuracy = correlation
    if args.val_only:
        print('Run validation ...')
        print('start epoch:', args.start_epoch, 'model:', args.resume)
        validate(val_loader, model, accuracy, args.start_epoch, None, None)
        return
    log_training = open(
        os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    # Persist the full argument namespace next to the logs for reproducibility.
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tb_writer = SummaryWriter(
        log_dir=os.path.join(args.root_log, args.store_name))
    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, log_training, tb_writer)
        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr or args.use_cos or args.use_multistep:
            scheduler.step()
        # Validate every eval_freq epochs (skipping the first two) and at the
        # final epoch.
        if (epoch + 1) > 2 and ((epoch + 1) % args.eval_freq == 0 or
                                (epoch + 1) == args.epochs):
            # validate
            if args.use_swa and epoch >= args.swa_start:
                # validate use swa model
                corr = validate(val_loader, swa_model, accuracy, epoch,
                                log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch,
                                log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % (best_corr)
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_corr': best_corr,
                }, is_best)
def prepare(args):
    """Assemble every training component for the VeRi-776 run.

    Builds the model, data loaders, optimizer, LR scheduler, and loss
    functions, optionally restoring all of them from a saved checkpoint.

    Returns:
        dict with keys 'start_epoch', 'model', 'train_loader', 'test_loader',
        'optimizer', 'scheduler', and 'losses'.
    """
    resume = args.resume_from_checkpoint
    t_begin = time.time()
    logger.info('global', 'Start preparing.')
    check_config_dir()
    logger.info('setting', config_info(), time_report=False)

    model = Baseline(num_classes=Config.nr_class)
    logger.info('setting', model_summary(model), time_report=False)
    logger.info('setting', str(model), time_report=False)

    # ImageNet normalization, shared by the train and test pipelines.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.RandomApply(
            [transforms.ColorJitter(brightness=0.3, contrast=0.3,
                                    saturation=0.3, hue=0)],
            p=0.5),
        transforms.RandomHorizontalFlip(),
        transforms.Pad(10),
        transforms.RandomCrop(Config.input_shape),
        transforms.ToTensor(),
        transforms.RandomErasing(),
        normalize,
    ])
    test_transforms = transforms.Compose([
        transforms.Resize(Config.input_shape),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = Veri776_train(transforms=train_transforms, need_attr=True)
    test_set = Veri776_test(transforms=test_transforms, need_attr=True)
    # P identities x K samples per batch for the triplet loss.
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=Config.batch_size,
        sampler=PKSampler(train_set, p=Config.P, k=Config.K),
        num_workers=Config.nr_worker,
        pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=Config.batch_size,
        sampler=torch.utils.data.SequentialSampler(test_set),
        num_workers=Config.nr_worker,
        pin_memory=True)

    optimizer = torch.optim.Adam(parm_list_with_Wdecay(model), lr=Config.lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                  lr_lambda=lr_multi_func)

    # Identity, attribute (type/color), and metric losses — all on GPU.
    losses = {
        'cross_entropy_loss': torch.nn.CrossEntropyLoss(),
        'type_ce_loss': torch.nn.CrossEntropyLoss(),
        'color_ce_loss': torch.nn.CrossEntropyLoss(),
        'triplet_hard_loss': triplet_hard_loss(margin=Config.triplet_margin),
    }
    losses = {name: fn.cuda() for name, fn in losses.items()}

    start_epoch = 0
    if resume and os.path.exists(Config.checkpoint_path):
        checkpoint = load_checkpoint()
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        # Continue from the epoch after the checkpointed one; otherwise
        # training simply starts at epoch 0.
        start_epoch = checkpoint['epoch'] + 1

    ret = {
        'start_epoch': start_epoch,
        'model': model,
        'train_loader': train_loader,
        'test_loader': test_loader,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'losses': losses,
    }
    mins, secs = sec2min_sec(t_begin, time.time())
    logger.info(
        'global',
        'Finish preparing, time spend: {}mins {}s.'.format(mins, secs))
    return ret