def get_d_paretomtl(grads, value, weights, i):
    """Calculate the gradient direction for ParetoMTL."""
    # check active constraints
    current_weight = weights[i]
    rest_weights = weights
    w = rest_weights - current_weight

    gx = torch.matmul(w, value / torch.norm(value))
    idx = gx > 0

    # calculate the descent direction
    if torch.sum(idx) <= 0:
        sol, nd = MinNormSolver.find_min_norm_element(
            [[grads[t]] for t in range(len(grads))])
        return torch.tensor(sol).cuda().float()

    vec = torch.cat((grads, torch.matmul(w[idx], grads)))
    sol, nd = MinNormSolver.find_min_norm_element(
        [[vec[t]] for t in range(len(vec))])

    weight0 = sol[0] + torch.sum(
        torch.stack([
            sol[j] * w[idx][j - 2, 0]
            for j in torch.arange(2, 2 + torch.sum(idx))
        ]))
    weight1 = sol[1] + torch.sum(
        torch.stack([
            sol[j] * w[idx][j - 2, 1]
            for j in torch.arange(2, 2 + torch.sum(idx))
        ]))
    weight = torch.stack([weight0, weight1])

    return weight
def get_d_paretomtl(grads, losses, preference_vectors, pref_idx):
    """Calculate the gradient direction for ParetoMTL.

    Args:
        grads: flattened gradients for each task
        losses: values of the losses for each task
        preference_vectors: all preference vectors u
        pref_idx: which index of u we are currently using
    """
    # check active constraints
    current_weight = preference_vectors[pref_idx]
    rest_weights = preference_vectors
    w = rest_weights - current_weight

    gx = torch.matmul(w, losses / torch.norm(losses))
    idx = gx > 0

    # calculate the descent direction
    if torch.sum(idx) <= 0:
        # here there are no active constraints in gx
        sol, nd = MinNormSolver.find_min_norm_element_FW(
            [[grads[t]] for t in range(len(grads))])
        return torch.tensor(sol).cuda().float()
    else:
        # we have active constraints, i.e. we have moved too far away from our preference vector
        vec = torch.cat((grads, torch.matmul(w[idx], grads)))
        sol, nd = MinNormSolver.find_min_norm_element(
            [[vec[t]] for t in range(len(vec))])
        sol = torch.Tensor(sol).cuda()

        # FIX: handle more than just 2 objectives
        n = preference_vectors.shape[1]
        weights = []
        for i in range(n):
            weight_i = sol[i] + torch.sum(
                torch.stack([
                    sol[j] * w[idx][j - n, i]
                    for j in torch.arange(n, n + torch.sum(idx))
                ]))
            weights.append(weight_i)

        weight = torch.stack(weights)

        return weight
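# Illustrative usage sketch (not from the source): assuming `flat_grads` is a
# (num_tasks, num_params) tensor of flattened task gradients, `task_losses` is a
# (num_tasks,) tensor, and `prefs` holds the preference vectors, the weights
# returned by get_d_paretomtl above could be turned into a single scalar loss for
# one backward pass. The rescaling by num_tasks / sum(|weights|) is a convention
# used in some ParetoMTL implementations, not necessarily this one.
def paretomtl_weighted_loss(task_losses, flat_grads, prefs, pref_idx):
    weights = get_d_paretomtl(flat_grads, task_losses.detach(), prefs, pref_idx)
    normalize_coeff = len(task_losses) / torch.sum(torch.abs(weights))
    weights = weights * normalize_coeff
    # scalar loss whose gradient follows the ParetoMTL descent direction
    return torch.sum(weights * task_losses)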
def get_d_paretomtl_init(grads, losses, preference_vectors, pref_idx):
    """Calculate the gradient direction for ParetoMTL initialization.

    Args:
        grads: flattened gradients for each task
        losses: values of the losses for each task
        preference_vectors: all preference vectors u
        pref_idx: which index of u we are currently using

    Returns:
        flag: is a feasible initial solution found?
        weight: weights for the task losses (zeros if a feasible solution was found)
    """
    flag = False
    nobj = losses.shape

    # check active constraints, Equation 7
    current_pref = preference_vectors[pref_idx]  # u_k
    w = preference_vectors - current_pref        # (u_j - u_k) \forall j = 1, ..., K
    gx = torch.matmul(w, losses / torch.norm(losses))  # in the paper they do not normalize the loss

    idx = gx > 0                      # I(\theta), i.e. the indices of the active constraints
    active_constraints = w[idx]       # constraints which are violated, i.e. gx > 0

    # calculate the descent direction
    if torch.sum(idx) <= 0:
        flag = True
        return flag, torch.zeros(nobj)
    if torch.sum(idx) == 1:
        sol = torch.ones(1).cuda().float()
    else:
        # Equation 9
        # w[idx] = set of active constraints, i.e. where the solution is closer to another
        # preference vector than to the desired one.
        gx_gradient = torch.matmul(active_constraints, grads)  # derivatives of G_j, which is w.dot(grads)
        sol, nd = MinNormSolver.find_min_norm_element(
            [[gx_gradient[t]] for t in range(len(gx_gradient))])
        sol = torch.Tensor(sol).cuda()

    # MinNormSolver returns the weights (alpha) for each gradient;
    # map them back to weights on the losses via the active constraints.
    weight = torch.matmul(sol, active_constraints)

    return flag, weight
def get_d_paretomtl_init(grads, value, weights, i):
    """Calculate the gradient direction for ParetoMTL initialization."""
    flag = False
    nobj = value.shape

    # check active constraints
    current_weight = weights[i]
    rest_weights = weights
    w = rest_weights - current_weight

    gx = torch.matmul(w, value / torch.norm(value))
    idx = gx > 0

    # calculate the descent direction
    if torch.sum(idx) <= 0:
        flag = True
        return flag, torch.zeros(nobj)
    if torch.sum(idx) == 1:
        sol = torch.ones(1).cuda().float()
    else:
        vec = torch.matmul(w[idx], grads)
        sol, nd = MinNormSolver.find_min_norm_element(
            [[vec[t]] for t in range(len(vec))])

    # generalization of the original two-objective version:
    # weight_t = sum_j sol[j] * w[idx][j, t]
    new_weights = []
    for t in range(len(value)):
        new_weights.append(
            torch.sum(
                torch.stack([
                    sol[j] * w[idx][j, t]
                    for j in torch.arange(0, torch.sum(idx))
                ])))

    return flag, torch.stack(new_weights)
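# Hedged sketch of the two-phase ParetoMTL schedule implied by the two functions
# above: a short warm-up that uses the init direction until a feasible solution is
# found, then the regular direction. `compute_task_losses`, `compute_flat_grads`,
# `train_loader`, `warmup_iters`, `preference_vectors`, `pref_idx`, `model` and
# `optimizer` are placeholders, not names from the source.
for it, batch in enumerate(train_loader):
    task_losses = compute_task_losses(model, batch)       # (num_tasks,) tensor
    flat_grads = compute_flat_grads(model, task_losses)   # (num_tasks, num_params)

    if it < warmup_iters:
        feasible, weights = get_d_paretomtl_init(flat_grads, task_losses.detach(),
                                                 preference_vectors, pref_idx)
        if feasible:
            # already inside the preferred sub-region; fall back to the main direction
            weights = get_d_paretomtl(flat_grads, task_losses.detach(),
                                      preference_vectors, pref_idx)
    else:
        weights = get_d_paretomtl(flat_grads, task_losses.detach(),
                                  preference_vectors, pref_idx)

    optimizer.zero_grad()
    torch.sum(weights * task_losses).backward()
    optimizer.step()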
def train_multi_task(param_file):
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    with open(param_file) as json_params:
        params = json.load(json_params)

    exp_identifier = []
    for (key, val) in params.items():
        if 'tasks' in key:
            continue
        exp_identifier += ['{}={}'.format(key, val)]
    exp_identifier = '|'.join(exp_identifier)
    params['exp_id'] = exp_identifier

    writer = SummaryWriter(log_dir='runs/{}_{}'.format(
        params['exp_id'], datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    train_loader, train_dst, val_loader, val_dst = datasets.get_dataset(params, configs)
    loss_fn = losses.get_loss(params)
    metric = metrics.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    if 'mgda' in params['algorithm']:
        approximate_norm_solution = params['use_approximation']
        if approximate_norm_solution:
            print('Using approximate min-norm solver')
        else:
            print('Using full solver')

    n_iter = 0
    loss_init = {}
    for epoch in tqdm(range(NUM_EPOCHS)):
        start = timer()
        print('Epoch {} Started'.format(epoch))
        if (epoch + 1) % 10 == 0:
            # Every 10 epochs, decay the LR by 0.85
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.85
            print('Decayed the learning rate at iteration {}'.format(n_iter))

        for m in model:
            model[m].train()

        for batch in train_loader:
            n_iter += 1
            # First member is always images
            images = batch[0]
            images = Variable(images.cuda())

            labels = {}
            # Read all targets of all tasks
            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels[t] = batch[i + 1]
                labels[t] = Variable(labels[t].cuda())

            # Scaling the loss functions based on the algorithm choice
            loss_data = {}
            grads = {}
            scale = {}
            mask = None
            masks = {}
            if 'mgda' in params['algorithm']:
                # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA
                if approximate_norm_solution:
                    optimizer.zero_grad()
                    # First compute representations (z)
                    images_volatile = Variable(images.data, volatile=True)
                    rep, mask = model['rep'](images_volatile, mask)
                    # As an approximate solution we only need gradients for input
                    if isinstance(rep, list):
                        # This is a hack to handle psp-net
                        rep = rep[0]
                        rep_variable = [Variable(rep.data.clone(), requires_grad=True)]
                        list_rep = True
                    else:
                        rep_variable = Variable(rep.data.clone(), requires_grad=True)
                        list_rep = False

                    # Compute gradients of each loss function wrt z
                    for t in tasks:
                        optimizer.zero_grad()
                        out_t, masks[t] = model[t](rep_variable, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        if list_rep:
                            grads[t].append(Variable(rep_variable[0].grad.data.clone(), requires_grad=False))
                            rep_variable[0].grad.data.zero_()
                        else:
                            grads[t].append(Variable(rep_variable.grad.data.clone(), requires_grad=False))
                            rep_variable.grad.data.zero_()
                else:
                    # This is MGDA
                    for t in tasks:
                        # Compute gradients of each loss function wrt parameters
                        optimizer.zero_grad()
                        rep, mask = model['rep'](images, mask)
                        out_t, masks[t] = model[t](rep, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        for param in model['rep'].parameters():
                            if param.grad is not None:
                                grads[t].append(Variable(param.grad.data.clone(), requires_grad=False))

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Frank-Wolfe iteration to compute scales.
                sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])
            else:
                for t in tasks:
                    masks[t] = None
                    scale[t] = float(params['scales'][t])

            # Scaled back-propagation
            optimizer.zero_grad()
            rep, _ = model['rep'](images, mask)
            for i, t in enumerate(tasks):
                out_t, _ = model[t](rep, masks[t])
                loss_t = loss_fn[t](out_t, labels[t])
                loss_data[t] = loss_t.data[0]
                if i > 0:
                    loss = loss + scale[t] * loss_t
                else:
                    loss = scale[t] * loss_t
            loss.backward()
            optimizer.step()

            writer.add_scalar('training_loss', loss.data[0], n_iter)
            for t in tasks:
                writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

        for m in model:
            model[m].eval()

        tot_loss = {}
        tot_loss['all'] = 0.0
        met = {}
        for t in tasks:
            tot_loss[t] = 0.0
            met[t] = 0.0

        num_val_batches = 0
        for batch_val in val_loader:
            val_images = Variable(batch_val[0].cuda(), volatile=True)
            labels_val = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_val[t] = batch_val[i + 1]
                labels_val[t] = Variable(labels_val[t].cuda(), volatile=True)

            val_rep, _ = model['rep'](val_images, None)
            for t in tasks:
                out_t_val, _ = model[t](val_rep, None)
                loss_t = loss_fn[t](out_t_val, labels_val[t])
                tot_loss['all'] += loss_t.data[0]
                tot_loss[t] += loss_t.data[0]
                metric[t].update(out_t_val, labels_val[t])
            num_val_batches += 1

        for t in tasks:
            writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t] / num_val_batches, n_iter)
            metric_results = metric[t].get_result()
            for metric_key in metric_results:
                writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
            metric[t].reset()
        writer.add_scalar('validation_loss', tot_loss['all'] / len(val_dst), n_iter)

        if epoch % 3 == 0:
            # Save after every 3 epochs
            state = {'epoch': epoch + 1,
                     'model_rep': model['rep'].state_dict(),
                     'optimizer_state': optimizer.state_dict()}
            for t in tasks:
                key_name = 'model_{}'.format(t)
                state[key_name] = model[t].state_dict()

            torch.save(state, "saved_models/{}_{}_model.pkl".format(params['exp_id'], epoch + 1))

        end = timer()
        print('Epoch ended in {}s'.format(end - start))
def step(self, batch):
    # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA
    if self.approximate_norm_solution:
        self.model.zero_grad()
        # First compute representations (z)
        with torch.no_grad():
            rep = self.model.forward_feature_extraction(batch)

        # Compute gradients of each loss function wrt z
        gradients = []
        obj_values = []
        for i, objective in enumerate(self.objectives):
            # zero grad
            self.model.zero_grad()

            logits = self.model.forward_linear(rep, i)
            batch.update(logits)

            output = objective(**batch)
            output.backward()

            obj_values.append(output.item())
            gradients.append({})

            private_params = self.model.private_params() if hasattr(self.model, 'private_params') else []
            for name, param in self.model.named_parameters():
                not_private = all([p not in name for p in private_params])
                if not_private and param.requires_grad and param.grad is not None:
                    gradients[i][name] = param.grad.data.detach().clone()
                    param.grad = None

        self.model.zero_grad()
        grads = gradients
    else:
        # This is MGDA
        grads, obj_values = calc_gradients(batch, self.model, self.objectives)

    # Normalize all gradients, this is optional and not included in the paper.
    gn = gradient_normalizers(grads, obj_values, self.normalization_type)
    for t in range(len(self.objectives)):
        for gr_i in grads[t]:
            grads[t][gr_i] = grads[t][gr_i] / gn[t]

    # Frank-Wolfe iteration to compute scales.
    grads = [[v for v in d.values()] for d in grads]
    sol, min_norm = MinNormSolver.find_min_norm_element(grads)

    # Scaled back-propagation
    self.model.zero_grad()
    logits = self.model(batch)
    batch.update(logits)
    loss_total = None
    for a, objective in zip(sol, self.objectives):
        task_loss = objective(**batch)
        loss_total = a * task_loss if loss_total is None else loss_total + a * task_loss

    loss_total.backward()
    return loss_total.item(), 0
def get_d_mgda(vec):
    r"""Calculate the gradient direction for MGDA."""
    sol, nd = MinNormSolver.find_min_norm_element(
        [[vec[t]] for t in range(len(vec))])
    return torch.tensor(sol).cuda().float()
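# Illustrative only: one way the MGDA weights from get_d_mgda could be consumed,
# assuming `flat_grads` stacks one flattened gradient per task and `task_losses`
# is a (num_tasks,) tensor of the corresponding losses. These names are
# placeholders, not part of the original code.
def mgda_weighted_loss(task_losses, flat_grads):
    sol = get_d_mgda(flat_grads)         # convex combination weights over tasks
    return torch.sum(sol * task_losses)  # scalar loss for a single backward()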
        task_losses.append(task_loss[0].item())
        loss_data[t] = task_loss.data
        task_loss.backward()

        grads[t] = []
        grads[t].append(
            Variable(rep_variable.grad.data.clone(), requires_grad=False))
        rep_variable.grad.data.zero_()

    # Normalize all gradients, this is optional and not included in the paper.
    gn = gradient_normalizers(grads, loss_data, 'none')
    for t in range(num_tasks):
        for gr_i in range(len(grads[t])):
            grads[t][gr_i] = grads[t][gr_i] / gn[t]

    # Frank-Wolfe iteration to compute scales.
    sol, min_norm = MinNormSolver.find_min_norm_element(
        [grads[t] for t in range(num_tasks)])
    for t in range(num_tasks):
        scale[t] = float(sol[t])

    # Scaled back-propagation
    optimizer.zero_grad()
    feats = model.forward_shared(train_data)
    for t in range(num_tasks):
        out_t = model.forward_task(feats, t)
        if t < 13:
            label = train_label[:, t, None, :, :]
        elif t == 13:
            label = train_depth
        else:
            label = train_normal
        loss_t = model.model_fit(out_t, label, t)
def compute_loss_coeff(optimizers, device, future_frames_num, all_disp_field_gt,
                       all_valid_pixel_maps, pixel_cat_map_gt, disp_pred, criterion,
                       non_empty_map, class_pred, motion_gt, motion_pred,
                       shared_feats_tensor):
    encoder_optimizer = optimizers[0]
    head_optimizer = optimizers[1]
    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()

    grads = {}

    # Compute the displacement loss
    gt = all_disp_field_gt[:, -future_frames_num:, ...].contiguous()
    gt = gt.view(-1, gt.size(2), gt.size(3), gt.size(4))
    gt = gt.permute(0, 3, 1, 2).to(device)

    valid_pixel_maps = all_valid_pixel_maps[:, -future_frames_num:, ...].contiguous()
    valid_pixel_maps = valid_pixel_maps.view(-1, valid_pixel_maps.size(2), valid_pixel_maps.size(3))
    valid_pixel_maps = torch.unsqueeze(valid_pixel_maps, 1)
    valid_pixel_maps = valid_pixel_maps.to(device)

    valid_pixel_num = torch.nonzero(valid_pixel_maps).size(0)
    if valid_pixel_num == 0:
        return [None] * 3

    # ---------------------------------------------------------------------
    # -- Generate the displacement w.r.t. the keyframe
    if pred_adj_frame_distance:
        disp_pred = disp_pred.view(-1, future_frames_num, disp_pred.size(-3),
                                   disp_pred.size(-2), disp_pred.size(-1))
        for c in range(1, disp_pred.size(1)):
            disp_pred[:, c, ...] = disp_pred[:, c, ...] + disp_pred[:, c - 1, ...]
        disp_pred = disp_pred.view(-1, disp_pred.size(-3), disp_pred.size(-2), disp_pred.size(-1))

    # ---------------------------------------------------------------------
    # -- Compute the masked displacement loss
    # Note: have also tried focal loss, but did not observe noticeable improvement
    pixel_cat_map_gt_numpy = pixel_cat_map_gt.numpy()
    pixel_cat_map_gt_numpy = np.argmax(pixel_cat_map_gt_numpy, axis=-1) + 1
    cat_weight_map = np.zeros_like(pixel_cat_map_gt_numpy, dtype=np.float32)
    weight_vector = [0.005, 1.0, 1.0, 1.0, 1.0]  # [bg, car & bus, ped, bike, other]
    for k in range(5):
        mask = pixel_cat_map_gt_numpy == (k + 1)
        cat_weight_map[mask] = weight_vector[k]

    cat_weight_map = cat_weight_map[:, np.newaxis, np.newaxis, ...]  # (batch, 1, 1, h, w)
    cat_weight_map = torch.from_numpy(cat_weight_map).to(device)
    map_shape = cat_weight_map.size()

    loss_disp = criterion(gt * valid_pixel_maps, disp_pred * valid_pixel_maps)
    loss_disp = loss_disp.view(map_shape[0], -1, map_shape[-3], map_shape[-2], map_shape[-1])
    loss_disp = torch.sum(loss_disp * cat_weight_map) / valid_pixel_num

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()
    loss_disp.backward(retain_graph=True)
    grads[0] = []
    grads[0].append(shared_feats_tensor.grad.data.clone().detach())
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Compute the grid cell classification loss
    non_empty_map = non_empty_map.view(-1, 256, 256)
    non_empty_map = non_empty_map.to(device)
    pixel_cat_map_gt = pixel_cat_map_gt.permute(0, 3, 1, 2).to(device)

    log_softmax_probs = F.log_softmax(class_pred, dim=1)

    map_shape = cat_weight_map.size()
    cat_weight_map = cat_weight_map.view(map_shape[0], map_shape[-2], map_shape[-1])  # (bs, h, w)
    loss_class = torch.sum(-pixel_cat_map_gt * log_softmax_probs, dim=1) * cat_weight_map
    loss_class = torch.sum(loss_class * non_empty_map) / torch.nonzero(non_empty_map).size(0)

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()
    loss_class.backward(retain_graph=True)
    grads[1] = []
    grads[1].append(shared_feats_tensor.grad.data.clone().detach())
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Compute the speed level classification loss
    motion_gt_numpy = motion_gt.numpy()
    motion_gt = motion_gt.permute(0, 3, 1, 2).to(device)
    log_softmax_motion_pred = F.log_softmax(motion_pred, dim=1)

    motion_gt_numpy = np.argmax(motion_gt_numpy, axis=-1) + 1
    motion_weight_map = np.zeros_like(motion_gt_numpy, dtype=np.float32)
    weight_vector = [0.005, 1.0]
    for k in range(2):
        mask = motion_gt_numpy == (k + 1)
        motion_weight_map[mask] = weight_vector[k]

    motion_weight_map = torch.from_numpy(motion_weight_map).to(device)
    loss_speed = torch.sum(-motion_gt * log_softmax_motion_pred, dim=1) * motion_weight_map
    loss_motion = torch.sum(loss_speed * non_empty_map) / torch.nonzero(non_empty_map).size(0)

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()
    loss_motion.backward(retain_graph=True)
    grads[2] = []
    grads[2].append(shared_feats_tensor.grad.data.clone().detach())
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Frank-Wolfe iteration to compute scales.
    scale = np.zeros(3, dtype=np.float32)
    sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in range(3)])
    for i in range(3):
        scale[i] = float(sol[i])

    return scale
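# Hedged sketch (not from the source) of how the returned `scale` could be used:
# recompute the three losses on a fresh forward pass and back-propagate their
# weighted sum through both optimizers. `loss_disp`, `loss_class` and `loss_motion`
# here stand for the same three losses computed inside compute_loss_coeff.
scale = compute_loss_coeff(optimizers, device, future_frames_num, all_disp_field_gt,
                           all_valid_pixel_maps, pixel_cat_map_gt, disp_pred,
                           criterion, non_empty_map, class_pred, motion_gt,
                           motion_pred, shared_feats_tensor)
if scale[0] is not None:
    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()
    total_loss = scale[0] * loss_disp + scale[1] * loss_class + scale[2] * loss_motion
    total_loss.backward()
    encoder_optimizer.step()
    head_optimizer.step()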
def train_multi_task(params, fold=0):
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    exp_identifier = []
    for (key, val) in params.items():
        if ('tasks' in key) or ('dataset' in key) or ('normalization_type' in key) \
                or ('grid_search' in key) or ('train' in key) or ('test' in key):
            continue
        exp_identifier += ['{}={}'.format(key, val)]
    exp_identifier = '|'.join(exp_identifier)

    if params['train']:
        train_loader, train_dst, val_loader, val_dst = dataset_selector.get_dataset(params, configs, fold)
        writer = SummaryWriter(log_dir='5fold_runs/{}_{}'.format(
            exp_identifier, datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))
    if params['test']:
        test_loader, test_dst = dataset_selector.get_dataset(params, configs)

    loss_fn = loss_selector.get_loss(params)
    metric = metrics_selector.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    model_params_num = 0
    for m in model:
        model_params += model[m].parameters()
        for parameter in model[m].parameters():
            model_params_num += parameter.numel()

    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    n_iter = 0
    loss_init = {}

    # early stopping
    count = 0
    init_val_plcc = 0
    best_epoch = 0

    # train
    if params['train']:
        for epoch in tqdm(range(NUM_EPOCHS)):
            start = timer()
            print('Epoch {} Started'.format(epoch))
            if (epoch + 1) % 30 == 0:
                # Every 30 epochs, halve the LR
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print('Halved the learning rate at iteration {}'.format(n_iter))

            for m in model:
                model[m].train()

            for batch in train_loader:
                n_iter += 1
                # First member is always images
                images = batch[0]
                images = Variable(images.cuda())

                labels = {}
                # Read all targets of all tasks
                for i, t in enumerate(all_tasks):
                    if t not in tasks:
                        continue
                    labels[t] = batch[i + 1]
                    labels[t] = Variable(labels[t].cuda())

                # Scaling the loss functions based on the algorithm choice
                loss_data = {}
                grads = {}
                scale = {}
                mask = None
                masks = {}

                # use algo MGDA_UB
                optimizer.zero_grad()
                # First compute representations (z)
                with torch.no_grad():
                    images_volatile = Variable(images.data)
                    rep, mask = model['rep'](images_volatile, mask)
                # As an approximate solution we only need gradients for input
                rep_variable = Variable(rep.data.clone(), requires_grad=True)
                list_rep = False

                # Compute gradients of each loss function wrt z
                for t in tasks:
                    optimizer.zero_grad()
                    out_t, masks[t] = model[t](rep_variable, None)
                    loss = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss.item()
                    loss.backward()
                    grads[t] = Variable(rep_variable.grad.data.clone(), requires_grad=False)
                    rep_variable.grad.data.zero_()

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Frank-Wolfe iteration to compute scales.
                sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])

                # Scaled back-propagation
                optimizer.zero_grad()
                rep, _ = model['rep'](images, mask)
                for i, t in enumerate(tasks):
                    out_t, _ = model[t](rep, masks[t])
                    loss_t = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss_t.item()
                    if i > 0:
                        loss = loss + scale[t] * loss_t
                    else:
                        loss = scale[t] * loss_t
                loss.backward()
                optimizer.step()

                writer.add_scalar('training_loss', loss.item(), n_iter)
                for t in tasks:
                    writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

            # validation
            for m in model:
                model[m].eval()

            tot_loss = {}
            tot_loss['all'] = 0.0
            met = {}
            for t in tasks:
                tot_loss[t] = 0.0
                met[t] = 0.0

            num_val_batches = 0
            for batch_val in val_loader:
                with torch.no_grad():
                    val_images = Variable(batch_val[0].cuda())
                labels_val = {}

                for i, t in enumerate(all_tasks):
                    if t not in tasks:
                        continue
                    labels_val[t] = batch_val[i + 1]
                    with torch.no_grad():
                        labels_val[t] = Variable(labels_val[t].cuda())

                val_rep, _ = model['rep'](val_images, None)
                for t in tasks:
                    out_t_val, _ = model[t](val_rep, None)
                    loss_t = loss_fn[t](out_t_val, labels_val[t])
                    tot_loss['all'] += loss_t.item()
                    tot_loss[t] += loss_t.item()
                    metric[t].update(out_t_val, labels_val[t])
                num_val_batches += 1

            avg_plcc = 0
            for t in tasks:
                writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t] / num_val_batches, n_iter)
                metric_results = metric[t].get_result()
                avg_plcc += metric_results['plcc']
                metric_str = 'task_{} : '.format(t)
                for metric_key in metric_results:
                    writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
                    metric_str += '{} = {} '.format(metric_key, metric_results[metric_key])
                metric[t].reset()
                metric_str += 'loss = {}'.format(tot_loss[t] / num_val_batches)
                print(metric_str)
            print('all loss = {}'.format(tot_loss['all'] / len(val_dst)))
            writer.add_scalar('validation_loss', tot_loss['all'] / len(val_dst), n_iter)

            avg_plcc /= 4
            print(avg_plcc)
            print(init_val_plcc)
            if init_val_plcc < avg_plcc:
                init_val_plcc = avg_plcc
                # save model weights if validation PLCC improves
                print('Saving model...')
                state = {'epoch': epoch + 1,
                         'model_rep': model['rep'].state_dict(),
                         'optimizer_state': optimizer.state_dict()}
                for t in tasks:
                    key_name = 'model_{}'.format(t)
                    state[key_name] = model[t].state_dict()
                torch.save(state, "saved_models/{}_{}_{}_model.pkl".format(exp_identifier, epoch + 1, fold))
                best_epoch = epoch + 1
                # reset count
                count = 0
            elif init_val_plcc >= avg_plcc:
                count += 1
                if count == 10:
                    print('Validation PLCC has not improved in %d epochs. Training terminated.' % 10)
                    break

            end = timer()
            print('Epoch ended in {}s'.format(end - start))

        print('Training completed.')
        return exp_identifier, init_val_plcc, best_epoch

    # test
    if params['test']:
        state = torch.load(os.path.join('./saved_models', "{}_{}_{}_model.pkl".format(
            params['exp_identifier'], params['best_epoch'], params['best_fold'])))
        model['rep'].load_state_dict(state['model_rep'])
        for t in tasks:
            key_name = 'model_{}'.format(t)
            model[t].load_state_dict(state[key_name])
        print('Successfully loaded {}_{}_{}_model'.format(
            params['exp_identifier'], params['best_epoch'], params['best_fold']))

        for m in model:
            model[m].eval()

        test_tot_loss = {}
        test_tot_loss['all'] = 0.0
        test_met = {}
        for t in tasks:
            test_tot_loss[t] = 0.0
            test_met[t] = 0.0

        num_test_batches = 0
        for batch_test in test_loader:
            with torch.no_grad():
                test_images = Variable(batch_test[0].cuda())
            labels_test = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_test[t] = batch_test[i + 1]
                with torch.no_grad():
                    labels_test[t] = Variable(labels_test[t].cuda())

            test_rep, _ = model['rep'](test_images, None)
            for t in tasks:
                out_t_test, _ = model[t](test_rep, None)
                test_loss_t = loss_fn[t](out_t_test, labels_test[t])
                test_tot_loss['all'] += test_loss_t.item()
                test_tot_loss[t] += test_loss_t.item()
                metric[t].update(out_t_test, labels_test[t])
            num_test_batches += 1

        print('test:')
        for t in tasks:
            test_metric_results = metric[t].get_result()
            test_metric_str = 'task_{} : '.format(t)
            for metric_key in test_metric_results:
                test_metric_str += '{} = {} '.format(metric_key, test_metric_results[metric_key])
            metric[t].reset()
            print(test_metric_str)
def _backward_step(self, input_train, target_train, input_valid, target_valid,
                   eta, network_optimizer):
    grads = {}
    loss_data = {}
    self.optimizer.zero_grad()

    if self.args.unrolled:
        unrolled_model = self._compute_unrolled_model(
            input_train, target_train, eta, network_optimizer)
    else:
        unrolled_model = self.model

    if self.args.adv_outer:
        input_valid = Variable(input_valid.data, requires_grad=True).cuda()

    # ---- acc loss ----
    unrolled_loss = unrolled_model._loss(input_valid, target_valid)
    loss_data['acc'] = unrolled_loss.data[0] / 2  # lossNorm
    grads['acc'] = list(
        torch.autograd.grad(unrolled_loss, unrolled_model.arch_parameters(),
                            retain_graph=True))
    # ---- acc loss end ----

    # ---- adv loss ----
    if self.args.adv_outer and (self.epoch >= self.args.adv_later):
        step_size = self.epsilon * 1.25
        delta = ((torch.rand(input_valid.size()) - 0.5) * 2).cuda() * self.epsilon
        adv_grad = torch.autograd.grad(unrolled_loss, input_valid,
                                       retain_graph=True, create_graph=False)[0]
        adv_grad = adv_grad.detach().data
        delta = clamp(delta + step_size * torch.sign(adv_grad), -self.epsilon, self.epsilon)
        delta = clamp(delta, self.lower_limit - input_valid.data,
                      self.upper_limit - input_valid.data)
        adv_input = Variable(input_valid.data + delta, requires_grad=False).cuda()

        self.optimizer.zero_grad()
        unrolled_loss_adv = unrolled_model._loss(adv_input, target_valid)
        grads['adv'] = list(
            torch.autograd.grad(unrolled_loss_adv, unrolled_model.arch_parameters(),
                                retain_graph=True))
        loss_data['adv'] = unrolled_loss_adv.data[0] / 2  # lossNorm
    # ---- adv loss end ----

    # ---- param loss ----
    if self.args.nop_outer and (self.epoch >= self.args.nop_later):
        self.optimizer.zero_grad()
        param_loss = self.param_number(unrolled_model)
        loss_data['nop'] = param_loss.data[0]
        grads['nop'] = list(
            torch.autograd.grad(param_loss, unrolled_model.arch_parameters(),
                                retain_graph=True))
    # ---- param loss end ----

    # ---- ood loss ----
    if self.args.ood_outer:
        self.optimizer.zero_grad()
        ood_logits = unrolled_model.forward(self.ood_input)
        ood_loss = F.kl_div(input=F.log_softmax(ood_logits),
                            target=torch.ones_like(ood_logits) / ood_logits.size()[-1])
        ood_loss = ood_loss * 50  # lossNorm, 10
        loss_data['ood'] = ood_loss.data[0]
        grads['ood'] = list(
            torch.autograd.grad(ood_loss, unrolled_model.arch_parameters(),
                                retain_graph=True))
    # ---- ood loss end ----

    # ---- flops loss ----
    if self.args.flp_outer:
        self.optimizer.zero_grad()
        flp_loss = self.cal_flops(unrolled_model)
        loss_data['flp'] = flp_loss.data[0]
        grads['flp'] = list(
            torch.autograd.grad(flp_loss, unrolled_model.arch_parameters(),
                                retain_graph=True))
    # ---- flops loss end ----

    gn = gradient_normalizers(grads, loss_data,
                              normalization_type=self.args.grad_norm)  # loss+, loss, l2
    for t in grads:
        for gr_i in range(len(grads[t])):
            grads[t][gr_i] = grads[t][gr_i] / (gn[t] + 1e-7)

    # ---- MGDA ----
    if self.args.MGDA and (len(grads) > 1):
        sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in grads])
        sol = [x + 1e-7 for x in sol]
    else:
        sol = [1] * len(grads)

    # acc, adv, nop
    loss = 0
    for kk, t in enumerate(grads):
        if t == 'acc':
            loss += float(sol[kk]) * unrolled_loss
        elif t == 'adv':
            loss += float(sol[kk]) * unrolled_loss_adv
        elif t == 'nop':
            loss += float(sol[kk]) * param_loss
        elif t == 'ood':
            loss += float(sol[kk]) * ood_loss
        elif t == 'flp':
            loss += float(sol[kk]) * flp_loss

    self.optimizer.zero_grad()
    loss.backward()
    # ---- MGDA end ----

    if self.args.unrolled:
        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(eta, ig.data)

        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = Variable(g.data)
            else:
                v.grad.data.copy_(g.data)

    logs = namedtuple("logs", ['sol', 'loss_data'])(sol, loss_data)
    print(logs)
    return logs
def train(train_loader, nets, optimizer, criterions, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    cls_losses = AverageMeter()
    half_losses = AverageMeter()
    st_losses = AverageMeter()
    collision_losses = AverageMeter()
    min_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    at_losses = AverageMeter()

    snet = nets['snet']
    tnet = nets['tnet']

    criterionCls = criterions['criterionCls']
    criterionST = criterions['criterionST']

    snet.train()

    end = time.time()
    for idx, (img, target) in enumerate(train_loader, start=1):
        data_time.update(time.time() - end)

        if args.cuda:
            img = img.cuda()
            target = target.cuda()

        img = Variable(img)
        optimizer.zero_grad()

        with torch.no_grad():
            images_volatile = Variable(img.data)
            _, _, _, _, output_s = snet(img)
            _, _, _, _, output_t = tnet(img)

        if isinstance(output_s, list):
            output_s = output_s[0]
            output_s_variable = [Variable(output_s.data.clone(), requires_grad=True)]
            list_rep = True
        else:
            output_s_variable = Variable(output_s.data.clone(), requires_grad=True)
            list_rep = False

        optimizer.zero_grad()

        target_reshape = target.reshape(-1, 1)
        target_onehot = torch.FloatTensor(output_s.shape[0], output_s.shape[1]).cuda()
        target_onehot.zero_()
        target_onehot.scatter_(1, target_reshape, 1)

        p = F.softmax(output_s / args.T, dim=1)
        q = F.softmax(output_t / args.T, dim=1)

        loss_data = {}
        grads = {}

        # Cross-entropy loss wrt the one-hot targets
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        p = F.softmax(output_s / args.T, dim=1)
        loss_ce = renyi_distill('shannon')(F.softmax(output_s, dim=1), target_onehot)
        loss_data[0] = loss_ce.data.item()
        loss_ce.backward()
        grads[0] = []
        for param in snet.parameters():
            if param.grad is not None:
                grads[0].append(
                    Variable(param.grad.data.clone(), requires_grad=False).reshape(-1))

        # 'half-fixed' divergence to the teacher
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        p = F.softmax(output_s / args.T, dim=1)
        loss_half = renyi_distill('half-fixed')(p, q) * (args.T ** 2)
        loss_data[1] = loss_half.data.item()
        loss_half.backward(retain_graph=True)
        grads[1] = []
        for param in snet.parameters():
            if param.grad is not None:
                grads[1].append(
                    Variable(param.grad.data.clone(), requires_grad=False).reshape(-1))

        # Shannon divergence to the teacher
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        p = F.softmax(output_s / args.T, dim=1)
        loss_shannon = renyi_distill('shannon')(p, q) * (args.T ** 2)
        loss_data[2] = loss_shannon.data.item()
        loss_shannon.backward(retain_graph=True)
        grads[2] = []
        for param in snet.parameters():
            if param.grad is not None:
                grads[2].append(
                    Variable(param.grad.data.clone(), requires_grad=False).reshape(-1))

        # Collision divergence to the teacher
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        p = F.softmax(output_s / args.T, dim=1)
        loss_collision = renyi_distill('collision')(p, q) * (args.T ** 2)
        loss_data[3] = loss_collision.data.item()
        loss_collision.backward(retain_graph=True)
        grads[3] = []
        for param in snet.parameters():
            if param.grad is not None:
                grads[3].append(
                    Variable(param.grad.data.clone(), requires_grad=False).reshape(-1))

        # Min divergence to the teacher
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        p = F.softmax(output_s / args.T, dim=1)
        loss_min = renyi_distill('min')(p, q) * (args.T ** 2)
        loss_data[4] = loss_min.data.item()
        loss_min.backward(retain_graph=True)
        grads[4] = []
        for param in snet.parameters():
            if param.grad is not None:
                grads[4].append(
                    Variable(param.grad.data.clone(), requires_grad=False).reshape(-1))

        gn = gradient_normalizers(grads, loss_data, 'l2')
        for t in range(5):
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in range(5)])
        scale = {}
        for t in range(5):
            scale[t] = float(sol[t])

        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        cls_losses.update(loss_ce.item(), img.size(0))
        half_losses.update(loss_half.item(), img.size(0))
        st_losses.update(loss_shannon.item(), img.size(0))
        collision_losses.update(loss_collision.item(), img.size(0))
        min_losses.update(loss_min.item(), img.size(0))
        top1.update(prec1.item(), img.size(0))
        top5.update(prec5.item(), img.size(0))

        # Scaled back-propagation with the MGDA weights
        optimizer.zero_grad()
        _, rb1_s, rb2_s, rb3_s, output_s = snet(img)

        loss_ce = renyi_distill('shannon')(F.softmax(output_s, dim=1), target_onehot)
        loss_data[0] = loss_ce.data.item()
        loss = scale[0] * loss_ce

        loss_half = renyi_distill('half')(p, q) * (args.T ** 2)
        loss_data[1] = loss_half.data.item()
        loss = loss + scale[1] * loss_half

        loss_shannon = renyi_distill('shannon')(p, q) * (args.T ** 2)
        loss_data[2] = loss_shannon.data.item()
        loss = loss + scale[2] * loss_shannon

        loss_collision = renyi_distill('collision')(p, q) * (args.T ** 2)
        loss_data[3] = loss_collision.data.item()
        loss = loss + scale[3] * loss_collision

        loss_min = renyi_distill('min')(p, q) * (args.T ** 2)
        loss_data[4] = loss_min.data.item()
        loss = loss + scale[4] * loss_min

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % args.print_freq == 0:
            print('Epoch[{0}]:[{1:03}/{2:03}] '
                  'Time:{batch_time.val:.4f} '
                  'Data:{data_time.val:.4f} '
                  'Cls:{cls_losses.val:.4f}({cls_losses.avg:.4f}) '
                  'Half:{half_losses.val:.4f}({half_losses.avg:.4f}) '
                  'ST:{st_losses.val:.4f}({st_losses.avg:.4f}) '
                  'Collision:{collision_losses.val:.4f}({collision_losses.avg:.4f}) '
                  'Min:{min_losses.val:.4f}({min_losses.avg:.4f}) '
                  'prec@1:{top1.val:.2f}({top1.avg:.2f}) '
                  'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                      epoch, idx, len(train_loader), batch_time=batch_time,
                      data_time=data_time, cls_losses=cls_losses,
                      half_losses=half_losses, st_losses=st_losses,
                      collision_losses=collision_losses, min_losses=min_losses,
                      top1=top1, top5=top5))