def main_train():
    global args, best_corr
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0
    if not args.val_only:
        check_rootfolders(args)

    # build the model
    if args.model == 'Baseline':
        if args.cls_indices:
            model = Baseline(args.img_feat_size, args.au_feat_size,
                             num_classes=len(args.cls_indices))
        else:
            print('Feature size:', args.img_feat_size, args.au_feat_size)
            model = Baseline(args.img_feat_size, args.au_feat_size)
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=128, num_classes=15, kernel_size=11)
    elif args.model == 'BaseAu':
        model = Baseline_Au(args.au_feat_size)
    elif args.model == 'BaseImg':
        model = Baseline_Img(args.img_feat_size)
    elif args.model == 'EmoBase':
        model = EmoBase()

    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate,
                                 weight_decay=args.weight_decay)
    # optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

    # custom optimizer (see the two-pass SAM update sketch after this function)
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)

    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)
    elif args.use_multistep:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, args.step_milestones, args.step_decay)

    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=args.learning_rate)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # initialize datasets
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter,
            train_freq=args.train_freq,
            val_freq=args.val_freq,
            cls_indices=args.cls_indices),
        batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=False, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val',
            train_freq=args.train_freq,
            val_freq=args.val_freq,
            cls_indices=args.cls_indices,
            repeat_sample=args.repeat_sample),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    accuracy = correlation

    if args.val_only:
        print('Run validation ...')
        print('start epoch:', args.start_epoch, 'model:', args.resume)
        validate(val_loader, model, accuracy, args.start_epoch, None, None)
        return

    log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, model, optimizer, epoch, log_training, tb_writer)

        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr or args.use_cos or args.use_multistep:
            scheduler.step()

        if (epoch + 1) > 2 and ((epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.epochs):
            # validate (use the SWA-averaged model once SWA has started)
            if args.use_swa and epoch >= args.swa_start:
                corr = validate(val_loader, swa_model, accuracy, epoch, log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch, log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % best_corr
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_corr': best_corr,
            }, is_best)
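# NOTE: the SAM wrapper above needs two forward/backward passes per batch, so the
# train() loop has to drive it differently from a plain optimizer. The sketch below
# is a minimal illustration assuming the widely used SAM implementation that exposes
# first_step()/second_step(); the loss_fn argument and the model(img_feat, au_feat)
# call signature are assumptions, not the project's actual train() code.
def sam_train_step_sketch(model, optimizer, loss_fn, img_feat, au_feat, labels):
    # first pass: compute the loss at the current weights and move to the
    # locally "sharpest" point in the weight neighbourhood
    loss = loss_fn(model(img_feat, au_feat), labels)
    loss.backward()
    optimizer.first_step(zero_grad=True)
    # second pass: re-evaluate the loss at the perturbed weights and take the
    # actual descent step with the wrapped base optimizer (Adam here)
    loss_fn(model(img_feat, au_feat), labels).backward()
    optimizer.second_step(zero_grad=True)
    return loss.item()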
def main_train(config, checkpoint_dir=None):
    # Ray Tune trainable variant: hyperparameters come from the Tune search space
    # via `config` (a launch sketch follows this function).
    global args, best_corr
    best_corr = 0.0
    args.store_name = '{}'.format(args.model)
    args.store_name = args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0
    # check_rootfolders(args)

    if args.model == 'Baseline':
        model = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(2048 + 128), num_classes=15, kernel_size=11)

    model = torch.nn.DataParallel(model).cuda()

    if config['optimizer'] == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    elif config['optimizer'] == 'adamw':
        optimizer = torch.optim.AdamW(model.parameters(), lr=config['lr'])

    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=config['lr'])

    # custom lr scheduler
    if args.use_cos_wr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=args.cos_wr_t0, T_mult=args.cos_wr_t_mult)
    elif args.use_cos:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)

    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=config['lr'])

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    # if args.resume and os.path.isfile(args.resume):
    #     print('Load checkpoint:', args.resume)
    #     ckpt = torch.load(args.resume)
    #     args.start_epoch = ckpt['epoch']
    #     best_corr = ckpt['best_corr']
    #     model.load_state_dict(ckpt['state_dict'])
    #     optimizer.load_state_dict(ckpt['optimizer'])
    #     print('Loaded ckpt at epoch:', args.start_epoch)

    # restore from a Ray Tune checkpoint if one is provided
    if checkpoint_dir:
        model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optimizer_state)

    # initialize datasets
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter),
        batch_size=config['batch_size'], shuffle=True,
        num_workers=args.workers, pin_memory=False, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val'),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    accuracy = correlation

    # with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
    #     f.write(str(args))
    # tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        # train
        train(train_loader, model, optimizer, epoch, None, None)

        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()

        # validate (use the SWA-averaged model once SWA has started)
        if args.use_swa and epoch >= args.swa_start:
            corr, loss = validate(val_loader, swa_model, accuracy, epoch, None, None)
        else:
            corr, loss = validate(val_loader, model, accuracy, epoch, None, None)
        is_best = corr > best_corr
        best_corr = max(corr, best_corr)
        # tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
        # output_best = 'Best corr: %.4f\n' % (best_corr)
        # print(output_best)
        # save_checkpoint({
        #     'epoch': epoch + 1,
        #     'state_dict': model.state_dict(),
        #     'optimizer': optimizer.state_dict(),
        #     'best_corr': best_corr,
        # }, is_best)

        # write a Ray Tune checkpoint and report the metrics for this epoch
        with tune.checkpoint_dir(epoch) as checkpoint_dir:
            path = os.path.join(checkpoint_dir, "checkpoint")
            if is_best:
                path = os.path.join(checkpoint_dir, "checkpoint_best")
            torch.save((model.state_dict(), optimizer.state_dict()), path)

        tune.report(loss=loss, accuracy=corr, best_corr=best_corr)
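# NOTE: because this main_train takes (config, checkpoint_dir) and calls
# tune.checkpoint_dir()/tune.report(), it is meant to be launched through Ray Tune's
# function-based API. A minimal launch sketch is shown below; the search-space values
# and trial resources are illustrative assumptions, not the project's actual settings.
#
# from ray import tune
#
# analysis = tune.run(
#     main_train,
#     config={
#         'optimizer': tune.choice(['adam', 'adamw']),
#         'lr': tune.loguniform(1e-5, 1e-2),
#         'batch_size': tune.choice([32, 64, 128]),
#     },
#     metric='best_corr',
#     mode='max',
#     num_samples=8,
#     resources_per_trial={'cpu': 4, 'gpu': 1})
# print('Best trial config:', analysis.best_config)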
def main_train():
    global args, best_corr
    args.store_name = '{}'.format(args.model)
    args.store_name = 'zzd' + args.store_name + datetime.now().strftime('_%m-%d_%H-%M-%S')
    args.start_epoch = 0
    check_rootfolders(args)

    # two-model variant: model and model2 are trained jointly
    if args.model == 'Baseline':
        model = Baseline()
        model2 = Baseline()
    elif args.model == 'TCFPN':
        model = TCFPN(layers=[48, 64, 96], in_channels=(2048 + 128), num_classes=15, kernel_size=11)

    model = torch.nn.DataParallel(model).cuda()
    model2 = torch.nn.DataParallel(model2).cuda()

    # optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    args.learning_rate = 0.02
    print('init: %f' % args.learning_rate)
    optimizer = torch.optim.SGD([
        {'params': model.parameters(), 'lr': args.learning_rate},
        {'params': model2.parameters(), 'lr': args.learning_rate},
    ], weight_decay=1e-4, momentum=0.9, nesterov=True)

    # custom optimizer
    if args.use_sam:
        base_optim = torch.optim.Adam
        optimizer = SAM(model.parameters(), base_optim, lr=args.learning_rate)

    # custom lr scheduler
    # print(args.use_cos_wr)
    # if args.use_cos_wr:
    #     args.cos_wr_t0 = 10
    #     print('using Restart: %d' % args.cos_wr_t0)
    #     scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=args.cos_wr_t0, T_mult=2)
    # elif args.use_cos:
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.cos_t_max)

    # SWA
    if args.use_swa:
        swa_model = torch.optim.swa_utils.AveragedModel(model)
        swa_scheduler = torch.optim.swa_utils.SWALR(optimizer, swa_lr=args.learning_rate)

    # ckpt structure {epoch, state_dict, optimizer, best_corr}
    if args.resume and os.path.isfile(args.resume):
        print('Load checkpoint:', args.resume)
        ckpt = torch.load(args.resume)
        args.start_epoch = ckpt['epoch']
        best_corr = ckpt['best_corr']
        model.load_state_dict(ckpt['state_dict'])
        optimizer.load_state_dict(ckpt['optimizer'])
        print('Loaded ckpt at epoch:', args.start_epoch)

    # initialize datasets (three independently shuffled training loaders)
    train_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter,
            train_freq=args.train_freq,
            val_freq=args.val_freq),
        batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=True)

    train_loader2 = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter,
            train_freq=args.train_freq,
            val_freq=args.val_freq),
        batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=True)

    train_loader3 = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.train_csv,
            vidmap_path=args.train_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='train',
            lpfilter=args.lp_filter,
            train_freq=args.train_freq,
            val_freq=args.val_freq),
        batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=True, drop_last=True)

    val_loader = torch.utils.data.DataLoader(
        dataset=EEV_Dataset(
            csv_path=args.val_csv,
            vidmap_path=args.val_vidmap,
            image_feat_path=args.image_features,
            audio_feat_path=args.audio_features,
            mode='val',
            train_freq=args.train_freq,
            val_freq=args.val_freq),
        batch_size=None, shuffle=False,
        num_workers=args.workers, pin_memory=False)

    accuracy = correlation

    log_training = open(os.path.join(args.root_log, args.store_name, 'log.csv'), 'w')
    with open(os.path.join(args.root_log, args.store_name, 'args.txt'), 'w') as f:
        f.write(str(args))
    tb_writer = SummaryWriter(log_dir=os.path.join(args.root_log, args.store_name))

    for epoch in range(args.start_epoch, args.epochs):
        train(train_loader, train_loader2, train_loader3, model, model2,
              optimizer, epoch, log_training, tb_writer)

        # do lr scheduling after epoch
        if args.use_swa and epoch >= args.swa_start:
            print('swa stepping...')
            swa_model.update_parameters(model)
            swa_scheduler.step()
        elif args.use_cos_wr:
            print('cos warm restart (T0:{} Tm:{}) stepping...'.format(args.cos_wr_t0, args.cos_wr_t_mult))
            scheduler.step()
        elif args.use_cos:
            print('cos (Tmax:{}) stepping...'.format(args.cos_t_max))
            scheduler.step()

        if (epoch + 1) % args.eval_freq == 0 or (epoch + 1) == args.epochs:
            # validate (use the SWA-averaged model once SWA has started)
            if args.use_swa and epoch >= args.swa_start:
                corr = validate(val_loader, swa_model, accuracy, epoch, log_training, tb_writer)
            else:
                corr = validate(val_loader, model, accuracy, epoch, log_training, tb_writer)
            is_best = corr > best_corr
            best_corr = max(corr, best_corr)
            tb_writer.add_scalar('acc/validate_corr_best', best_corr, epoch)
            output_best = 'Best corr: %.4f\n' % best_corr
            print(output_best)
            log_training.write(output_best + '\n')
            log_training.flush()
            # (a sketch of a compatible save_checkpoint helper follows this function)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_corr': best_corr,
            }, is_best)
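# NOTE: save_checkpoint() is called above but defined elsewhere in the repo. Below is
# a minimal sketch of a helper compatible with how it is called here; the directory
# and file names (under args.root_log/args.store_name) are assumptions, not the
# project's actual checkpoint layout.
def save_checkpoint_sketch(state, is_best):
    # state is the dict assembled above: {epoch, state_dict, optimizer, best_corr}
    ckpt_dir = os.path.join(args.root_log, args.store_name)
    torch.save(state, os.path.join(ckpt_dir, 'ckpt.pth.tar'))
    if is_best:
        # keep a separate copy of the best-scoring checkpoint
        torch.save(state, os.path.join(ckpt_dir, 'ckpt.best.pth.tar'))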