def train(train_train_loader, train_test_loader, test_test_loader, modality,
          naming):
    """Train a BackboneNet on the given modality and periodically evaluate.

    Accumulates gradients over ``batch_size`` single-video forward passes,
    then performs one optimizer step ("virtual batching"). Every ``log_freq``
    update steps the model is evaluated on the train (and optionally test)
    split and metrics are written to TensorBoard via ``Logger``. Checkpoints
    are saved at the step indices listed in ``check_points``.

    Args:
        train_train_loader: iterable of training samples (dicts with
            'label', 'weight', 'rgb' and/or 'flow' tensors).
        train_test_loader: loader used to evaluate on the train split.
        test_test_loader: loader used to evaluate on the test split.
        modality: one of 'both', 'rgb', 'flow'.
        naming: experiment name used for the log and model directories.

    Relies on module-level config (``feature_dim``, ``model_params``,
    ``learning_rate``, ``batch_size``, ``max_step_num``, ...), the global
    ``device``, and sibling helpers (``test``, ``get_diversity_loss``,
    ``get_norm_regularization``).
    """
    assert (modality in ['both', 'rgb', 'flow'])

    log_dir = os.path.join('logs', naming, modality)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = Logger(log_dir)

    save_dir = os.path.join('models', naming)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # 'both' concatenates rgb and flow features, doubling the input width.
    if modality == 'both':
        model = BackboneNet(in_features=feature_dim * 2,
                            **model_params).to(device)
    else:
        model = BackboneNet(in_features=feature_dim,
                            **model_params).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)

    if learning_rate_decay:
        # Single 10x LR drop at the halfway point of training.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[max_step_num // 2], gamma=0.1)

    optimizer.zero_grad()

    # FIX: 'elementwise_mean' was deprecated in PyTorch 0.4.1 and removed in
    # later releases (raises ValueError); 'mean' is the equivalent mode and
    # matches the other train functions in this file.
    criterion = nn.CrossEntropyLoss(reduction='mean')

    update_step_idx = 0   # number of optimizer steps taken
    single_video_idx = 0  # number of videos seen (gradient accumulations)

    loss_recorder = {
        'cls': 0,
        'div': 0,
        'norm': 0,
        'sum': 0,
    }

    while update_step_idx < max_step_num:

        # Train loop
        for _, data in enumerate(train_train_loader):

            model.train()
            single_video_idx += 1

            label = data['label'].to(device)
            weight = data['weight'].to(device).float()

            if modality == 'both':
                rgb = data['rgb'].to(device)
                flow = data['flow'].to(device)
                model_input = torch.cat([rgb, flow], dim=2)
            elif modality == 'rgb':
                model_input = data['rgb'].to(device)
            else:
                model_input = data['flow'].to(device)

            # (batch, time, feature) -> (batch, feature, time) for the net.
            model_input = model_input.transpose(2, 1)

            _, _, out, scores, _ = model(model_input)

            loss_cls = criterion(out, label) * weight

            if diversity_reg:
                loss_div = get_diversity_loss(scores) * weight
                loss_div = loss_div * diversity_weight

                loss_norm = get_norm_regularization(scores) * weight
                loss_norm = loss_norm * diversity_weight

                loss = loss_cls + loss_div + loss_norm

                loss_recorder['div'] += loss_div.item()
                loss_recorder['norm'] += loss_norm.item()
            else:
                loss = loss_cls

            loss_recorder['cls'] += loss_cls.item()
            loss_recorder['sum'] += loss.item()

            # Gradients accumulate across videos; the optimizer steps once
            # per batch_size videos below.
            loss.backward()

            # Test and Update
            if single_video_idx % batch_size == 0:

                # Test
                if update_step_idx % log_freq == 0:
                    train_acc, train_loss, train_map = test(
                        model, train_test_loader, modality)
                    logger.scalar_summary('Train Accuracy', train_acc,
                                          update_step_idx)
                    logger.scalar_summary('Train map', train_map,
                                          update_step_idx)
                    for k in train_loss.keys():
                        logger.scalar_summary('Train Loss {}'.format(k),
                                              train_loss[k], update_step_idx)

                    if args.test_log:
                        test_acc, test_loss, test_map = test(
                            model, test_test_loader, modality)
                        logger.scalar_summary('Test Accuracy', test_acc,
                                              update_step_idx)
                        logger.scalar_summary('Test map', test_map,
                                              update_step_idx)
                        for k in test_loss.keys():
                            logger.scalar_summary('Test Loss {}'.format(k),
                                                  test_loss[k],
                                                  update_step_idx)

                # Batch Update
                update_step_idx += 1

                for k, v in loss_recorder.items():
                    print('Step {}: Loss_{}-{}'.format(
                        update_step_idx, k, v / batch_size))
                    logger.scalar_summary('Loss_{}_ps'.format(k),
                                          v / batch_size, update_step_idx)
                    loss_recorder[k] = 0

                optimizer.step()
                optimizer.zero_grad()

                if learning_rate_decay:
                    scheduler.step()

                if update_step_idx in check_points:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            save_dir,
                            'model-{}-{}'.format(modality, update_step_idx)))

                if update_step_idx >= max_step_num:
                    break
def my_train(train_train_loader, train_test_loader, test_test_loader,
             modality, naming, label_2_video, num_of_video):
    """Train a BackboneNet with class-balanced sampling plus extra losses.

    Unlike ``train``, batches are drawn manually via ``get_sample_list``
    (which returns video indices and, for the first ``len(label_idx)``
    entries, their class labels), and two additional losses are added once
    per batch: a separation loss over per-class attention/features and a
    triplet loss over global attention / base features. Gradients accumulate
    over ``batch_size`` videos before each optimizer step.

    Args:
        train_train_loader: indexable dataset; each item is a dict of numpy
            arrays ('label', 'weight', 'rgb'/'flow', ...).
        train_test_loader: evaluation loader for the train split (unused
            here; the test hook below is a ``pass``).
        test_test_loader: evaluation loader for the test split (unused).
        modality: one of 'both', 'rgb', 'flow'.
        naming: experiment name used for log and model directories.
        label_2_video: dict mapping class label -> videos of that class.
        num_of_video: total number of videos available for sampling.

    Relies on module-level config and the global ``device``, plus sibling
    helpers ``get_sample_list``, ``sep_loss`` and ``triplet_loss``.
    """
    assert (modality in ['both', 'rgb', 'flow'])

    log_dir = os.path.join('logs', naming, modality)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = Logger(log_dir)

    save_dir = os.path.join('models', naming)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # 'both' concatenates rgb and flow features, doubling the input width.
    if modality == 'both':
        model = BackboneNet(in_features=feature_dim * 2,
                            **model_params).to(device)
    else:
        model = BackboneNet(in_features=feature_dim,
                            **model_params).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)

    if learning_rate_decay:
        # Single 10x LR drop at the halfway point of training.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[max_step_num // 2], gamma=0.1)

    optimizer.zero_grad()

    criterion = nn.CrossEntropyLoss(reduction='mean')

    update_step_idx = 0   # number of optimizer steps taken
    single_video_idx = 0  # number of videos seen (gradient accumulations)

    loss_recorder = {
        'cls': 0,
        'div': 0,
        'norm': 0,
        'sum': 0,
        'sep': 0,
        'trip': 0
    }
    # add
    # NOTE(review): all_class is computed but never used below — dead code?
    all_class = list(label_2_video.keys())
    all_class.sort()

    # add state sample video
    while update_step_idx < max_step_num:
        # Draw one batch of video indices; the first len(label_idx) entries
        # carry class labels used for the separation/triplet losses.
        idx, label_idx = get_sample_list(num_of_video, label_2_video,
                                         batch_size, n_similar)
        # NOTE(review): debug prints on every batch — consider removing or
        # gating behind a verbosity flag.
        print("sample index")
        print(idx)
        print("label index")
        print(label_idx)

        # Per-batch collections feeding sep_loss / triplet_loss.
        atten_set = []         # per-class attention slices
        fea_set = []           # fused features for labeled samples
        global_atten_set = []  # full attention weights
        base_fea_set = []      # base (pre-fusion) features

        # Batch loss accumulator; re-created fresh each batch.
        loss = torch.Tensor([0]).to(device)

        for cnt in range(len(idx)):
            tmp_id = idx[cnt]
            tmp_label = None
            if cnt < len(label_idx):
                tmp_label = label_idx[cnt]

            data = train_train_loader[tmp_id]

            model.train()
            single_video_idx += 1

            label = torch.from_numpy(data['label']).long().to(device)
            weight = torch.from_numpy(data['weight']).float().to(device)

            if modality == 'both':
                rgb = torch.from_numpy(data['rgb']).float().to(device)
                flow = torch.from_numpy(data['flow']).float().to(device)
                # Add a batch dim to unbatched (time, feature) arrays.
                if len(rgb.shape) == 2:
                    rgb = rgb.unsqueeze(0)
                if len(flow.shape) == 2:
                    flow = flow.unsqueeze(0)
                model_input = torch.cat([rgb, flow], dim=2)
            elif modality == 'rgb':
                model_input = torch.from_numpy(
                    data['rgb']).float().to(device)
            else:
                model_input = torch.from_numpy(
                    data['flow']).float().to(device)

            # print(model_input.shape)
            if len(model_input.shape) == 2:
                model_input = model_input.unsqueeze(0)

            # Promote scalar labels/weights to 1-element tensors so the
            # criterion sees a batch dimension.
            if len(label.shape) == 0:
                label = label.unsqueeze(0)
                weight = weight.unsqueeze(0)

            # print(model_input.shape)
            # (batch, time, feature) -> (batch, feature, time) for the net.
            model_input = model_input.transpose(2, 1)

            avg_score, att_weight, out, scores, feature_dict = model(
                model_input)

            # add
            # Only labeled samples contribute to the sep/triplet inputs.
            if tmp_label is not None:
                atten_set.append(avg_score[:, :, tmp_label:tmp_label + 1])
                fea_set.append(feature_dict['fuse_feature'])
                global_atten_set.append(att_weight)
                base_fea_set.append(feature_dict['base_feature'])

            loss_cls = criterion(out, label) * weight

            # add sep flag
            # True on the last video of each batch of batch_size.
            sep_flag = (single_video_idx % batch_size == 0)

            if sep_flag:
                # NOTE(review): trip_loss_weight = 0 disables the triplet
                # term; presumably left in for experimentation — confirm.
                sep_loss_weight = 1
                trip_loss_weight = 0

                loss = loss + loss_cls
                loss_recorder['cls'] += loss_cls.item()

                if len(global_atten_set) > 0:
                    loss_sep = sep_loss_weight * sep_loss(
                        atten_set, fea_set, device)
                    # add separation loss and cluster loss
                    loss_trip = trip_loss_weight * triplet_loss(
                        global_atten_set, base_fea_set, label_idx, device)

                    loss = loss + loss_sep + loss_trip

                    # NOTE(review): '=' (not '+=') — these record only the
                    # last batch's value, unlike 'cls' which accumulates.
                    loss_recorder['sep'] = loss_sep.item()
                    loss_recorder['trip'] = loss_trip.item()

                # Backprop once per batch through the accumulated loss.
                loss.backward()
            else:
                loss = loss + loss_cls
                loss_recorder['cls'] += loss_cls.item()

            # loss is the cumulative sum
            loss_recorder['sum'] = loss.item()

            # Test and Update
            if single_video_idx % batch_size == 0:
                # calculate sep loss
                # Test
                # NOTE(review): evaluation hook intentionally disabled here.
                if update_step_idx % log_freq == 0:
                    pass

                # Batch Update
                update_step_idx += 1

                for k, v in loss_recorder.items():
                    logger.scalar_summary('Loss_{}_ps'.format(k),
                                          v / batch_size, update_step_idx)
                    loss_recorder[k] = 0

                optimizer.step()
                optimizer.zero_grad()

                if learning_rate_decay:
                    scheduler.step()

                if update_step_idx in check_points:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            save_dir,
                            'model-{}-{}'.format(modality, update_step_idx)))

                if update_step_idx >= max_step_num:
                    break
def train(train_train_loader, train_test_loader, test_test_loader, modality,
          naming, lr_val):
    """Train a BackboneNet at a given learning rate and plot the loss curve.

    Variant of the basic training loop that (a) reads all hyperparameters
    from the module-level ``all_params`` dict, (b) takes the learning rate
    as an explicit argument (for LR sweeps), (c) skips samples with empty
    feature arrays, and (d) periodically saves a matplotlib loss-vs-step
    plot instead of TensorBoard summaries.

    NOTE(review): this redefines the earlier ``train`` in this module —
    the second definition shadows the first; confirm that is intended.

    Args:
        train_train_loader: iterable of training samples (dicts with
            'label', 'weight', 'rgb'/'flow' tensors).
        train_test_loader: loader used to evaluate on the train split.
        test_test_loader: loader used to evaluate on the test split.
        modality: one of 'both', 'rgb', 'flow'.
        naming: experiment name used for log and model directories.
        lr_val: learning rate for Adam; also embedded in the plot filename.
    """
    plt.figure()  # For plotting loss graph
    loss_val = []   # per-step averaged 'sum' loss values
    step_list = []  # matching update-step indices
    assert (modality in ['both', 'rgb', 'flow'])

    log_dir = os.path.join('logs', naming, modality)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    save_dir = os.path.join('models', naming)
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # 'both' concatenates rgb and flow features, doubling the input width.
    if modality == 'both':
        model = BackboneNet(in_features=all_params['feature_dim'] * 2,
                            **all_params['model_params']).to(device)
    else:
        model = BackboneNet(in_features=all_params['feature_dim'],
                            **all_params['model_params']).to(device)

    optimizer = optim.Adam(model.parameters(),
                           lr=lr_val,
                           weight_decay=all_params['weight_decay'])

    if all_params['learning_rate_decay']:
        # Single 10x LR drop at the halfway point of training.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=[all_params['max_step_num'] // 2],
            gamma=0.1)

    optimizer.zero_grad()

    criterion = nn.CrossEntropyLoss(reduction='mean')

    update_step_idx = 0   # number of optimizer steps taken
    single_video_idx = 0  # number of videos seen (gradient accumulations)

    loss_recorder = {
        'cls': 0,
        'div': 0,
        'norm': 0,
        'sum': 0,
    }

    while update_step_idx < all_params['max_step_num']:

        # Train loop
        for _, data in enumerate(train_train_loader):

            model.train()
            single_video_idx += 1

            label = data['label'].to(device)
            weight = data['weight'].to(device).float()

            model_input = None

            # Skip samples whose feature array is empty along the time axis.
            if modality == 'both':
                if data['rgb'].shape[1] != 0 and data['flow'].shape[1] != 0:
                    rgb = data['rgb'].to(device)
                    flow = data['flow'].to(device)
                    model_input = torch.cat([rgb, flow], dim=2)
                else:
                    continue
            elif modality == 'rgb':
                if data['rgb'].shape[1] != 0:
                    model_input = data['rgb'].to(device)
                else:
                    continue
            else:
                if data['flow'].shape[1] != 0:
                    model_input = data['flow'].to(device)
                else:
                    continue

            # Transpose (batch, time, feature) -> (batch, feature, time)
            # only when the last dim is still the feature dim.
            # RGB or Flow
            if model_input.shape[-1] == all_params['feature_dim']:
                model_input = model_input.transpose(2, 1)
            # Both
            if modality == 'both' and model_input.shape[
                    -1] == all_params['feature_dim'] * 2:
                model_input = model_input.transpose(2, 1)

            _, _, out, scores, _ = model(model_input)

            loss_cls = criterion(out, label) * weight

            if all_params['diversity_reg']:
                loss_div = get_diversity_loss(scores) * weight
                loss_div = loss_div * all_params['diversity_weight']

                loss_norm = get_norm_regularization(scores) * weight
                loss_norm = loss_norm * all_params['diversity_weight']

                loss = loss_cls + loss_div + loss_norm

                loss_recorder['div'] += loss_div.item()
                loss_recorder['norm'] += loss_norm.item()
            else:
                loss = loss_cls

            loss_recorder['cls'] += loss_cls.item()
            loss_recorder['sum'] += loss.item()

            # Gradients accumulate across videos; the optimizer steps once
            # per batch_size videos below.
            loss.backward()

            # Test and Update
            if single_video_idx % all_params['batch_size'] == 0:

                # Test
                if update_step_idx % all_params['log_freq'] == 0:
                    train_acc, train_loss = test(model, train_test_loader,
                                                 modality)

                    if args.test_log:
                        test_acc, test_loss = test(model, test_test_loader,
                                                   modality)
                        print('Train Accuracy:{}, Test Accuracy:{}'.format(
                            train_acc, test_acc))

                # Batch Update
                update_step_idx += 1

                for k, v in loss_recorder.items():
                    print('Step {}: Loss_{}-{}'.format(
                        update_step_idx, k,
                        v / all_params['batch_size']))

                    # Plot loss over every iterations
                    if k == 'sum':
                        step_list.append(update_step_idx)
                        loss_val.append(v / all_params['batch_size'])
                        # NOTE(review): train_acc is only refreshed every
                        # log_freq steps, so this title may show a stale
                        # accuracy value for intermediate steps.
                        plt.title('Classification accuracy: {:.4f}'.format(
                            train_acc))
                        plt.xlabel('Iterations')
                        plt.ylabel('Loss')

                    loss_recorder[k] = 0

                if update_step_idx % all_params[
                        'log_freq'] == 0:  # Plot graph every 500 iterations
                    # NOTE(review): plot is re-drawn onto the same figure
                    # without clearing, so each save layers a new line over
                    # the previous ones — confirm this is intended.
                    plt.plot(step_list, loss_val)
                    img = 'Loss_LR-{}.png'.format(lr_val)
                    img_pth = os.path.join(log_dir, img)
                    plt.savefig(img_pth)

                optimizer.step()
                optimizer.zero_grad()

                if all_params['learning_rate_decay']:
                    scheduler.step()

                if update_step_idx in all_params['check_points']:
                    torch.save(
                        model.state_dict(),
                        os.path.join(
                            save_dir,
                            'model-{}-{}'.format(modality, update_step_idx)))

                if update_step_idx >= all_params['max_step_num']:
                    break