Beispiel #1
0
def get_d_paretomtl(grads, value, weights, i):
    """Calculate the gradient direction (per-task weights) for ParetoMTL.

    Works for any number of objectives; for two objectives the result is the
    same as the former hard-coded ``weight0`` / ``weight1`` computation.

    Args:
        grads: per-task gradients. Indexed along dim 0 and used in
            ``matmul(w[idx], grads)``, so presumably a 2-D tensor with one
            flattened gradient per row -- TODO confirm against the caller.
        value: tensor with the current loss value of every task.
        weights: tensor holding all preference vectors, one per row.
        i: row index into ``weights`` of the preference vector in use.

    Returns:
        1-D tensor with one weight per task, on the same device as
        ``weights``.
    """
    # check active constraints
    w = weights - weights[i]
    gx = torch.matmul(w, value / torch.norm(value))
    idx = gx > 0

    # calculate the descent direction
    if torch.sum(idx) <= 0:
        # No active constraint: the solver coefficients are the weights.
        sol, _ = MinNormSolver.find_min_norm_element(
            [[grads[t]] for t in range(len(grads))])
        # Stay on the inputs' device instead of forcing CUDA.
        return torch.as_tensor(sol, dtype=torch.float32, device=w.device)

    active = w[idx]  # rows of w whose constraint is active (gx > 0)
    vec = torch.cat((grads, torch.matmul(active, grads)))
    sol, _ = MinNormSolver.find_min_norm_element(
        [[vec[t]] for t in range(len(vec))])
    sol = torch.as_tensor(sol, dtype=active.dtype, device=active.device)

    # The first n_obj coefficients belong to the task gradients themselves,
    # the remaining ones to the active constraints. Folding the latter back
    # through `active` gives, for every task t:
    #     sol[t] + sum_j sol[n_obj + j] * active[j, t]
    n_obj = len(grads)
    return sol[:n_obj] + torch.matmul(sol[n_obj:], active)
Beispiel #2
0
def get_d_paretomtl(grads, losses, preference_vectors, pref_idx):
    """
    Compute the ParetoMTL gradient direction as one weight per objective.

    Args:
        grads: flattened gradients for each task
        losses: values of the losses for each task
        preference_vectors: all preference vectors u
        pref_idx: which index of u we are currently using

    Returns:
        Tensor of task weights. Without active constraints these are the
        coefficients from ``MinNormSolver.find_min_norm_element_FW``;
        otherwise the solver coefficients are combined with the active
        constraint rows.
    """
    # Differences between every preference vector and the selected one.
    w = preference_vectors - preference_vectors[pref_idx]
    scaled_losses = losses / torch.norm(losses)
    gx = torch.matmul(w, scaled_losses)
    idx = gx > 0

    if torch.sum(idx) <= 0:
        # No constraint in gx is active.
        task_grads = [[grads[t]] for t in range(len(grads))]
        sol, _ = MinNormSolver.find_min_norm_element_FW(task_grads)
        return torch.tensor(sol).cuda().float()

    # Active constraints: we have moved too far away from our preference
    # vector.
    active = w[idx]
    vec = torch.cat((grads, torch.matmul(active, grads)))
    stacked = [[vec[t]] for t in range(len(vec))]
    sol, _ = MinNormSolver.find_min_norm_element(stacked)
    sol = torch.Tensor(sol).cuda()

    # Handles any number of objectives: entries [0, n) of sol belong to the
    # task gradients, entries from n onwards to the active constraints.
    n = preference_vectors.shape[1]
    num_active = torch.sum(idx)
    per_objective = []
    for i in range(n):
        constraint_terms = [
            sol[j] * active[j - n, i]
            for j in torch.arange(n, n + num_active)
        ]
        per_objective.append(sol[i] + torch.sum(torch.stack(constraint_terms)))

    return torch.stack(per_objective)
Beispiel #3
0
def get_d_paretomtl_init(grads, losses, preference_vectors, pref_idx):
    """
    Compute the gradient direction for the ParetoMTL initialization phase.

    Args:
        grads: flattened gradients for each task
        losses: values of the losses for each task
        preference_vectors: all preference vectors u
        pref_idx: which index of u we are currently using

    Returns:
        flag: True when no constraint is active, i.e. a feasible initial
            solution has been found
        weight: zeros shaped like ``losses`` when ``flag`` is True, otherwise
            the solver coefficients multiplied with the active constraints
    """
    # Constraint check (the original comments refer to "Equation 7"):
    # (u_j - u_k) for all j = 1, ..., K
    offsets = preference_vectors - preference_vectors[pref_idx]
    # The loss is normalized here; per the original note the paper does not
    # normalize it.
    gx = torch.matmul(offsets, losses / torch.norm(losses))
    idx = gx > 0  # I(theta): indexes of the active constraints

    # Constraints which are violated, i.e. gx > 0.
    active_constraints = offsets[idx]
    num_active = torch.sum(idx)

    if num_active <= 0:
        # Nothing is violated: report success with an all-zero weight.
        return True, torch.zeros(losses.shape)

    if num_active == 1:
        # A single constraint gets the full weight.
        sol = torch.ones(1).cuda().float()
    else:
        # "Equation 9" in the original comments: gradients of the active
        # constraints G_j, obtained as active_constraints.dot(grads).
        gx_gradient = torch.matmul(active_constraints, grads)
        constraint_grads = [[gx_gradient[t]] for t in range(len(gx_gradient))]
        sol, _ = MinNormSolver.find_min_norm_element(constraint_grads)
        sol = torch.Tensor(sol).cuda()

    # MinNormSolver yields one coefficient (alpha) per constraint gradient;
    # map them back onto the losses through the constraint rows.
    return False, torch.matmul(sol, active_constraints)
Beispiel #4
0
def get_d_paretomtl_init(grads, value, weights, i):
    """Calculate the gradient direction for ParetoMTL initialization.

    Args:
        grads: per-task gradients, used as ``matmul(w[idx], grads)``;
            presumably a 2-D tensor with one flattened gradient per row --
            TODO confirm against the caller.
        value: 1-D tensor with the current loss value of every task.
        weights: 2-D tensor holding all preference vectors, one per row.
        i: row index into ``weights`` of the preference vector in use.

    Returns:
        Tuple ``(flag, weight)``. ``flag`` is True when no constraint is
        active; ``weight`` is then an all-zero tensor shaped like ``value``.
        Otherwise ``flag`` is False and ``weight`` holds one entry per task,
        on the same device as ``weights``.
    """
    nobj = value.shape

    # check active constraints
    w = weights - weights[i]
    gx = torch.matmul(w, value / torch.norm(value))
    idx = gx > 0

    # Select the active rows once instead of re-evaluating the boolean mask
    # for every element of the result.
    active = w[idx]
    num_active = active.shape[0]

    # calculate the descent direction
    if num_active == 0:
        # Unchanged from before: zeros are created on the default device.
        return True, torch.zeros(nobj)

    if num_active == 1:
        # A single active constraint gets the full weight. Created on the
        # inputs' device so the function also runs without CUDA.
        sol = torch.ones(1, dtype=active.dtype, device=active.device)
    else:
        vec = torch.matmul(active, grads)
        sol, _ = MinNormSolver.find_min_norm_element(
            [[vec[t]] for t in range(len(vec))])
        sol = torch.as_tensor(sol, dtype=active.dtype, device=active.device)

    # weight[t] = sum_j sol[j] * active[j, t], computed as one
    # vector-matrix product.
    return False, torch.matmul(sol, active)
def train_multi_task(param_file):
    """Train a multi-task model described by a JSON parameter file.

    Reads ``configs.json`` and ``param_file``, builds data loaders, losses,
    metrics, the model dictionary and an optimizer, then runs ``NUM_EPOCHS``
    epochs of training followed by validation. When ``params['algorithm']``
    contains ``'mgda'`` the per-task loss scales are computed with
    ``MinNormSolver``; otherwise they are read from ``params['scales']``.
    A checkpoint is written to ``saved_models/`` whenever ``epoch % 3 == 0``.

    Args:
        param_file: path to a JSON file with the experiment parameters
            (keys read here: ``optimizer``, ``lr``, ``tasks``, ``dataset``,
            ``algorithm``, and depending on the algorithm
            ``use_approximation``, ``normalization_type``, ``scales``).

    NOTE(review): ``writer`` and ``NUM_EPOCHS`` are not defined inside this
    function -- presumably module-level names; confirm, since the
    SummaryWriter creation below is commented out.
    NOTE(review): the code uses ``Variable(..., volatile=True)`` and
    ``loss.data[0]``, which looks like the pre-0.4 PyTorch API -- confirm the
    targeted PyTorch version.
    """
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    with open(param_file) as json_params:
        params = json.load(json_params)


    # Build an experiment identifier "key=value|key=value|..." from every
    # parameter whose key does not contain 'tasks'.
    exp_identifier = []
    for (key, val) in params.items():
        if 'tasks' in key:
            continue
        exp_identifier+= ['{}={}'.format(key,val)]

    exp_identifier = '|'.join(exp_identifier)
    params['exp_id'] = exp_identifier

    # Disabled writer creation (kept for reference); `writer` is still used
    # further down.
    #writer = SummaryWriter(log_dir='runs/{}_{}'.format(params['exp_id'], datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    train_loader, train_dst, val_loader, val_dst = datasets.get_dataset(params, configs)
    loss_fn = losses.get_loss(params)
    metric = metrics.get_metrics(params)

    # `model` is a dict of sub-models; 'rep' is the shared representation
    # network and the remaining keys used below are task names.
    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    # NOTE(review): no fallback branch -- `optimizer` stays undefined when
    # params['optimizer'] matches none of the three names.
    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    if 'mgda' in params['algorithm']:
        approximate_norm_solution = params['use_approximation']
        if approximate_norm_solution:
            print('Using approximate min-norm solver')
        else:
            print('Using full solver')
    n_iter = 0
    loss_init = {}
    for epoch in tqdm(range(NUM_EPOCHS)):
        start = timer()
        print('Epoch {} Started'.format(epoch))
        if (epoch+1) % 10 == 0:
            # Every 10th epoch, multiply the LR by 0.85 (the log message
            # below still says "Half").
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.85
            print('Half the learning rate{}'.format(n_iter))

        for m in model:
            model[m].train()

        for batch in train_loader:
            n_iter += 1
            # First member is always images
            images = batch[0]
            images = Variable(images.cuda())

            labels = {}
            # Read all targets of all tasks; batch[i+1] is the target of the
            # i-th entry of all_tasks, tasks not being trained are skipped.
            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels[t] = batch[i+1]
                labels[t] = Variable(labels[t].cuda())

            # Scaling the loss functions based on the algorithm choice
            loss_data = {}
            grads = {}
            scale = {}
            mask = None
            masks = {}
            if 'mgda' in params['algorithm']:
                # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA

                if approximate_norm_solution:
                    optimizer.zero_grad()
                    # First compute representations (z)
                    images_volatile = Variable(images.data, volatile=True)
                    rep, mask = model['rep'](images_volatile, mask)
                    # As an approximate solution we only need gradients for input
                    if isinstance(rep, list):
                        # This is a hack to handle psp-net
                        rep = rep[0]
                        rep_variable = [Variable(rep.data.clone(), requires_grad=True)]
                        list_rep = True
                    else:
                        rep_variable = Variable(rep.data.clone(), requires_grad=True)
                        list_rep = False

                    # Compute gradients of each loss function wrt z
                    for t in tasks:
                        optimizer.zero_grad()
                        out_t, masks[t] = model[t](rep_variable, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        # Store a copy of the gradient wrt the representation
                        # and clear it before the next task's backward pass.
                        if list_rep:
                            grads[t].append(Variable(rep_variable[0].grad.data.clone(), requires_grad=False))
                            rep_variable[0].grad.data.zero_()
                        else:
                            grads[t].append(Variable(rep_variable.grad.data.clone(), requires_grad=False))
                            rep_variable.grad.data.zero_()
                else:
                    # This is MGDA
                    for t in tasks:
                        # Compute gradients of each loss function wrt parameters
                        optimizer.zero_grad()
                        rep, mask = model['rep'](images, mask)
                        out_t, masks[t] = model[t](rep, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        for param in model['rep'].parameters():
                            if param.grad is not None:
                                grads[t].append(Variable(param.grad.data.clone(), requires_grad=False))

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Frank-Wolfe iteration to compute scales.
                sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])
            else:
                # Fixed, user-supplied scales.
                for t in tasks:
                    masks[t] = None
                    scale[t] = float(params['scales'][t])

            # Scaled back-propagation
            optimizer.zero_grad()
            rep, _ = model['rep'](images, mask)
            for i, t in enumerate(tasks):
                out_t, _ = model[t](rep, masks[t])
                loss_t = loss_fn[t](out_t, labels[t])
                loss_data[t] = loss_t.data[0]
                # The first task initializes the total, later ones are added.
                if i > 0:
                    loss = loss + scale[t]*loss_t
                else:
                    loss = scale[t]*loss_t
            loss.backward()
            optimizer.step()

            writer.add_scalar('training_loss', loss.data[0], n_iter)
            for t in tasks:
                writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

        # ---- validation ----
        for m in model:
            model[m].eval()

        tot_loss = {}
        tot_loss['all'] = 0.0
        met = {}
        for t in tasks:
            tot_loss[t] = 0.0
            met[t] = 0.0

        num_val_batches = 0
        for batch_val in val_loader:
            val_images = Variable(batch_val[0].cuda(), volatile=True)
            labels_val = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_val[t] = batch_val[i+1]
                labels_val[t] = Variable(labels_val[t].cuda(), volatile=True)

            val_rep, _ = model['rep'](val_images, None)
            for t in tasks:
                out_t_val, _ = model[t](val_rep, None)
                loss_t = loss_fn[t](out_t_val, labels_val[t])
                tot_loss['all'] += loss_t.data[0]
                tot_loss[t] += loss_t.data[0]
                metric[t].update(out_t_val, labels_val[t])
            num_val_batches+=1

        for t in tasks:
            writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t]/num_val_batches, n_iter)
            metric_results = metric[t].get_result()
            for metric_key in metric_results:
                writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
            metric[t].reset()
        # NOTE(review): the per-task losses above are divided by the number
        # of validation batches, this total by len(val_dst) -- confirm the
        # different denominators are intended.
        writer.add_scalar('validation_loss', tot_loss['all']/len(val_dst), n_iter)

        if epoch % 3 == 0:
            # Save after every 3 epoch
            state = {'epoch': epoch+1,
                    'model_rep': model['rep'].state_dict(),
                    'optimizer_state' : optimizer.state_dict()}
            for t in tasks:
                key_name = 'model_{}'.format(t)
                state[key_name] = model[t].state_dict()

            torch.save(state, "saved_models/{}_{}_model.pkl".format(params['exp_id'], epoch+1))

        end = timer()
        print('Epoch ended in {}s'.format(end - start))
Beispiel #6
0
    def step(self, batch):
        """Run one MGDA / MGDA-UB step and back-propagate the scaled loss.

        Per-objective gradients are collected, normalized with
        ``gradient_normalizers`` and handed to
        ``MinNormSolver.find_min_norm_element``. The returned coefficients
        weight the objectives in a final forward/backward pass. No optimizer
        step is taken here.

        Args:
            batch: mapping passed to every objective as keyword arguments
                (``objective(**batch)``). It is updated in place with the
                model outputs via ``batch.update(...)``.

        Returns:
            Tuple ``(loss, 0)`` where ``loss`` is the Python float value of
            the weighted total loss.
        """
        # Will use MGDA_UB if approximate_norm_solution is True. Otherwise,
        # will use MGDA.
        if self.approximate_norm_solution:
            self.model.zero_grad()
            # First compute representations (z) without tracking gradients.
            with torch.no_grad():
                rep = self.model.forward_feature_extraction(batch)

            # Compute the gradients of each objective given the fixed
            # representation.
            gradients = []
            obj_values = []
            for i, objective in enumerate(self.objectives):
                self.model.zero_grad()

                logits = self.model.forward_linear(rep, i)
                batch.update(logits)

                output = objective(**batch)
                output.backward()

                obj_values.append(output.item())
                gradients.append({})

                private_params = self.model.private_params() if hasattr(
                    self.model, 'private_params') else []
                for name, param in self.model.named_parameters():
                    # Skip parameters whose name contains any of the model's
                    # private-parameter markers.
                    not_private = all([p not in name for p in private_params])
                    if not_private and param.requires_grad and param.grad is not None:
                        gradients[i][name] = param.grad.data.detach().clone()
                        param.grad = None
                self.model.zero_grad()

            grads = gradients
        else:
            # This is MGDA.
            # presumably calc_gradients returns one {name: grad} dict per
            # objective, like the branch above -- verify against its
            # definition.
            grads, obj_values = calc_gradients(batch, self.model,
                                               self.objectives)

        # Normalize all gradients, this is optional and not included in the
        # paper.
        gn = gradient_normalizers(grads, obj_values, self.normalization_type)
        for t in range(len(self.objectives)):
            for gr_i in grads[t]:
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        # Frank-Wolfe iteration to compute scales.
        grads = [[v for v in d.values()] for d in grads]
        sol, min_norm = MinNormSolver.find_min_norm_element(grads)

        # Scaled back-propagation
        self.model.zero_grad()
        logits = self.model(batch)
        batch.update(logits)
        loss_total = None
        for a, objective in zip(sol, self.objectives):
            task_loss = objective(**batch)
            weighted = a * task_loss
            # Compare with None explicitly: a truthiness test on a tensor is
            # False for the value 0, which would silently discard an
            # accumulated total of exactly zero (and its gradient graph).
            loss_total = weighted if loss_total is None else loss_total + weighted

        loss_total.backward()
        return loss_total.item(), 0
Beispiel #7
0
def get_d_mgda(vec):
    r"""Return the MGDA direction coefficients for the gradients in ``vec``.

    Every entry ``vec[t]`` is wrapped in its own single-element list and the
    resulting list is passed to ``MinNormSolver.find_min_norm_element``. The
    solver's first return value is converted to a float32 CUDA tensor.
    """
    wrapped = [[vec[k]] for k in range(len(vec))]
    coefficients, _ = MinNormSolver.find_min_norm_element(wrapped)
    as_tensor = torch.tensor(coefficients)
    return as_tensor.cuda().float()
Beispiel #8
0
            task_losses.append(task_loss[0].item())
            loss_data[t] = task_loss.data
            task_loss.backward()
            grads[t] = []
            grads[t].append(
                Variable(rep_variable.grad.data.clone(), requires_grad=False))
            rep_variable.grad.data.zero_()

        # Normalize all gradients, this is optional and not included in the paper.
        gn = gradient_normalizers(grads, loss_data, 'none')
        for t in range(num_tasks):
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        # Frank-Wolfe iteration to compute scales.
        sol, min_norm = MinNormSolver.find_min_norm_element(
            [grads[t] for t in range(num_tasks)])
        for t in range(num_tasks):
            scale[t] = float(sol[t])

        # Scaled back-propagation
        optimizer.zero_grad()
        feats = model.forward_shared(train_data)
        for t in range(num_tasks):
            out_t = model.forward_task(feats, t)
            if t < 13:
                label = train_label[:, t, None, :, :]
            elif t == 13:
                label = train_depth
            else:
                label = train_normal
            loss_t = model.model_fit(out_t, label, t)
def compute_loss_coeff(optimizers, device, future_frames_num,
                       all_disp_field_gt, all_valid_pixel_maps,
                       pixel_cat_map_gt, disp_pred, criterion, non_empty_map,
                       class_pred, motion_gt, motion_pred,
                       shared_feats_tensor):
    """Compute MGDA scales for the displacement, class and motion losses.

    Each of the three losses is back-propagated separately and the gradient
    that accumulates on ``shared_feats_tensor`` is copied; the three copies
    are then passed to ``MinNormSolver.find_min_norm_element``.

    Args:
        optimizers: sequence whose first two entries are the encoder and head
            optimizers; only their ``zero_grad()`` is called here.
        device: target device for the tensors moved with ``.to(device)``.
        future_frames_num: number of trailing frames taken from dim 1 of
            ``all_disp_field_gt`` and ``all_valid_pixel_maps``.
        all_disp_field_gt: ground-truth displacement fields.
        all_valid_pixel_maps: validity masks matching the displacement fields.
        pixel_cat_map_gt: ground-truth category map; ``.numpy()`` is called on
            it, so presumably a CPU tensor with categories one-hot encoded in
            the last dim -- TODO confirm.
        disp_pred: predicted displacement.
        criterion: loss callable; its output is reshaped with ``.view`` below,
            so presumably an element-wise (unreduced) loss -- TODO confirm.
        non_empty_map: mask of non-empty cells, viewed as (-1, 256, 256).
        class_pred: category logits (log-softmax is taken over dim 1).
        motion_gt: ground-truth motion state; ``.numpy()`` is called on it.
        motion_pred: motion logits (log-softmax is taken over dim 1).
        shared_feats_tensor: shared feature tensor whose ``.grad`` is read
            after every backward pass; presumably set up by the caller to
            retain its gradient -- TODO confirm.

    Returns:
        ``[None] * 3`` when there is no valid pixel, otherwise a float32 numpy
        array with the three scales in the order displacement, class, motion.

    NOTE(review): ``pred_adj_frame_distance`` is not a parameter --
    presumably a module-level flag; confirm.
    """
    encoder_optimizer = optimizers[0]
    head_optimizer = optimizers[1]
    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()

    # Per-task gradient copies, keyed 0 (disp), 1 (class), 2 (motion).
    grads = {}

    # Compute the displacement loss
    # Keep only the last `future_frames_num` frames, merge dims 0 and 1, and
    # move the last dim to position 1.
    gt = all_disp_field_gt[:, -future_frames_num:, ...].contiguous()
    gt = gt.view(-1, gt.size(2), gt.size(3), gt.size(4))
    gt = gt.permute(0, 3, 1, 2).to(device)

    valid_pixel_maps = all_valid_pixel_maps[:, -future_frames_num:,
                                            ...].contiguous()
    valid_pixel_maps = valid_pixel_maps.view(-1, valid_pixel_maps.size(2),
                                             valid_pixel_maps.size(3))
    valid_pixel_maps = torch.unsqueeze(valid_pixel_maps, 1)
    valid_pixel_maps = valid_pixel_maps.to(device)

    # Nothing to learn from when no pixel is valid; this also avoids the
    # division by zero below.
    valid_pixel_num = torch.nonzero(valid_pixel_maps).size(0)
    if valid_pixel_num == 0:
        return [None] * 3

    # ---------------------------------------------------------------------
    # -- Generate the displacement w.r.t. the keyframe
    if pred_adj_frame_distance:
        # Running sum over the frame dim: each frame's displacement is added
        # to the previous frame's (in place on the reshaped prediction).
        disp_pred = disp_pred.view(-1, future_frames_num, disp_pred.size(-3),
                                   disp_pred.size(-2), disp_pred.size(-1))
        for c in range(1, disp_pred.size(1)):
            disp_pred[:, c,
                      ...] = disp_pred[:, c, ...] + disp_pred[:, c - 1, ...]
        disp_pred = disp_pred.view(-1, disp_pred.size(-3), disp_pred.size(-2),
                                   disp_pred.size(-1))

    # ---------------------------------------------------------------------
    # -- Compute the masked displacement loss
    # Note: have also tried focal loss, but did not observe noticeable improvement
    # Category ids are 1-based: argmax over the last axis, plus one.
    pixel_cat_map_gt_numpy = pixel_cat_map_gt.numpy()
    pixel_cat_map_gt_numpy = np.argmax(pixel_cat_map_gt_numpy, axis=-1) + 1
    cat_weight_map = np.zeros_like(pixel_cat_map_gt_numpy, dtype=np.float32)
    weight_vector = [0.005, 1.0, 1.0, 1.0,
                     1.0]  # [bg, car & bus, ped, bike, other]
    for k in range(5):
        mask = pixel_cat_map_gt_numpy == (k + 1)
        cat_weight_map[mask] = weight_vector[k]

    cat_weight_map = cat_weight_map[:, np.newaxis, np.newaxis,
                                    ...]  # (batch, 1, 1, h, w)
    cat_weight_map = torch.from_numpy(cat_weight_map).to(device)
    map_shape = cat_weight_map.size()

    loss_disp = criterion(gt * valid_pixel_maps, disp_pred * valid_pixel_maps)
    loss_disp = loss_disp.view(map_shape[0], -1, map_shape[-3], map_shape[-2],
                               map_shape[-1])
    loss_disp = torch.sum(loss_disp * cat_weight_map) / valid_pixel_num

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()

    # retain_graph=True keeps the graph alive for the backward passes of the
    # other losses further down.
    loss_disp.backward(retain_graph=True)
    grads[0] = []
    grads[0].append(shared_feats_tensor.grad.data.clone().detach())
    # Clear the buffer so the next loss's gradient is not accumulated on top.
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Compute the grid cell classification loss
    # NOTE(review): the 256 x 256 map size is hard-coded here.
    non_empty_map = non_empty_map.view(-1, 256, 256)
    non_empty_map = non_empty_map.to(device)
    pixel_cat_map_gt = pixel_cat_map_gt.permute(0, 3, 1, 2).to(device)

    log_softmax_probs = F.log_softmax(class_pred, dim=1)

    map_shape = cat_weight_map.size()
    cat_weight_map = cat_weight_map.view(map_shape[0], map_shape[-2],
                                         map_shape[-1])  # (bs, h, w)
    # Weighted cross-entropy per cell, averaged over the non-empty cells.
    # NOTE(review): unlike valid_pixel_num above, the non-empty count is not
    # checked for zero before dividing.
    loss_class = torch.sum(-pixel_cat_map_gt * log_softmax_probs,
                           dim=1) * cat_weight_map
    loss_class = torch.sum(
        loss_class * non_empty_map) / torch.nonzero(non_empty_map).size(0)

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()

    loss_class.backward(retain_graph=True)
    grads[1] = []
    grads[1].append(shared_feats_tensor.grad.data.clone().detach())
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Compute the speed level classification loss
    motion_gt_numpy = motion_gt.numpy()
    motion_gt = motion_gt.permute(0, 3, 1, 2).to(device)
    log_softmax_motion_pred = F.log_softmax(motion_pred, dim=1)

    # Same 1-based argmax scheme as for the categories, with two classes
    # weighted 0.005 and 1.0.
    motion_gt_numpy = np.argmax(motion_gt_numpy, axis=-1) + 1
    motion_weight_map = np.zeros_like(motion_gt_numpy, dtype=np.float32)
    weight_vector = [0.005, 1.0]
    for k in range(2):
        mask = motion_gt_numpy == (k + 1)
        motion_weight_map[mask] = weight_vector[k]

    motion_weight_map = torch.from_numpy(motion_weight_map).to(device)
    loss_speed = torch.sum(-motion_gt * log_softmax_motion_pred,
                           dim=1) * motion_weight_map
    loss_motion = torch.sum(
        loss_speed * non_empty_map) / torch.nonzero(non_empty_map).size(0)

    encoder_optimizer.zero_grad()
    head_optimizer.zero_grad()

    loss_motion.backward(retain_graph=True)
    grads[2] = []
    grads[2].append(shared_feats_tensor.grad.data.clone().detach())
    shared_feats_tensor.grad.data.zero_()

    # ---------------------------------------------------------------------
    # -- Frank-Wolfe iteration to compute scales.
    scale = np.zeros(3, dtype=np.float32)
    sol, min_norm = MinNormSolver.find_min_norm_element(
        [grads[t] for t in range(3)])
    for i in range(3):
        scale[i] = float(sol[i])

    return scale
def _gather_task_labels(batch, all_tasks, tasks):
    """Return ``{task: target}`` for the selected tasks, moved to the GPU.

    ``batch[0]`` holds the images; the target of the i-th entry of
    ``all_tasks`` is read from ``batch[i + 1]``.
    """
    return {
        t: batch[i + 1].cuda()
        for i, t in enumerate(all_tasks)
        if t in tasks
    }


def _evaluate_tasks(model, loader, tasks, all_tasks, loss_fn, metric):
    """Run every task head over ``loader`` without building autograd graphs.

    ``metric[t]`` is updated in place for each task.

    Returns:
        ``(tot_loss, num_batches)`` where ``tot_loss[t]`` is the sum of the
        per-batch loss values of task ``t`` and ``tot_loss['all']`` is the sum
        over all tasks.
    """
    tot_loss = {'all': 0.0}
    for t in tasks:
        tot_loss[t] = 0.0

    num_batches = 0
    # Evaluation never calls backward(), so skip graph construction entirely.
    with torch.no_grad():
        for batch in loader:
            images = batch[0].cuda()
            labels = _gather_task_labels(batch, all_tasks, tasks)

            rep, _ = model['rep'](images, None)
            for t in tasks:
                out_t, _ = model[t](rep, None)
                loss_value = loss_fn[t](out_t, labels[t]).item()
                tot_loss['all'] += loss_value
                tot_loss[t] += loss_value
                metric[t].update(out_t, labels[t])
            num_batches += 1
    return tot_loss, num_batches


def train_multi_task(params, fold=0):
    """Train (with per-epoch validation) or test a multi-task model.

    The per-task loss weights of every training step come from
    ``MinNormSolver.find_min_norm_element`` applied to the gradients of each
    task loss with respect to the shared representation.

    Args:
        params: dict of settings. Keys read here: 'train', 'test',
            'optimizer', 'lr', 'tasks', 'dataset', 'normalization_type' and,
            in test mode, 'exp_identifier', 'best_epoch', 'best_fold'.
        fold: forwarded to ``dataset_selector.get_dataset`` for training and
            embedded in the checkpoint file name.

    Returns:
        ``(exp_identifier, best_val_plcc, best_epoch)`` when
        ``params['train']`` is truthy (the test branch is then not reached);
        otherwise ``None``.

    Raises:
        ValueError: if training is requested and ``params['optimizer']`` names
            none of RMSprop, Adam or SGD.
    """
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    # Keys containing any of these words are left out of the experiment id.
    excluded_words = ('tasks', 'dataset', 'normalization_type', 'grid_search',
                      'train', 'test')
    exp_identifier = '|'.join(
        '{}={}'.format(key, val)
        for key, val in params.items()
        if not any(word in key for word in excluded_words))

    if params['train']:
        train_loader, train_dst, val_loader, val_dst = dataset_selector.get_dataset(params, configs, fold)
        writer = SummaryWriter(log_dir='5fold_runs/{}_{}'.format(exp_identifier, datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    if params['test']:
        test_loader, test_dst = dataset_selector.get_dataset(params, configs)

    loss_fn = loss_selector.get_loss(params)
    metric = metrics_selector.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    # The optimizer is only needed for training; an unknown name is reported
    # there instead of surfacing later as a NameError.
    optimizer = None
    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    if params['train']:
        if optimizer is None:
            raise ValueError(
                'Unsupported optimizer: {}'.format(params['optimizer']))

        n_iter = 0

        # Early stopping on the mean validation PLCC.
        patience = 10
        lr_halving_period = 30
        epochs_without_gain = 0
        best_val_plcc = 0
        best_epoch = 0

        for epoch in tqdm(range(NUM_EPOCHS)):
            start = timer()
            print('Epoch {} Started'.format(epoch))
            if (epoch + 1) % lr_halving_period == 0:
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print('Half the learning rate {}'.format(n_iter))

            for m in model:
                model[m].train()

            for batch in train_loader:
                n_iter += 1

                # First member is always images
                images = batch[0].cuda()
                labels = _gather_task_labels(batch, all_tasks, tasks)

                loss_data = {}
                grads = {}
                scale = {}
                mask = None
                masks = {}

                optimizer.zero_grad()
                # First compute the shared representation z. Only its values
                # are needed (it is detached right below), so no graph is
                # built for this pass.
                with torch.no_grad():
                    rep, mask = model['rep'](images, mask)
                # As an approximate solution we only need gradients w.r.t. z.
                rep_variable = rep.detach().clone().requires_grad_(True)

                # Compute gradients of each loss function wrt z
                for t in tasks:
                    optimizer.zero_grad()
                    out_t, masks[t] = model[t](rep_variable, None)
                    loss = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss.item()
                    loss.backward()
                    grads[t] = rep_variable.grad.detach().clone()
                    rep_variable.grad.zero_()

                # Normalize all gradients. grads[t] is a single tensor here,
                # so the index loop rescales it slice by slice along its
                # first dimension.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Min-norm solver gives one scale per task.
                sol, min_norm = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])

                # Scaled back-propagation through the full model.
                optimizer.zero_grad()
                rep, _ = model['rep'](images, mask)
                loss = None
                for t in tasks:
                    out_t, _ = model[t](rep, masks[t])
                    loss_t = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss_t.item()
                    weighted = scale[t] * loss_t
                    loss = weighted if loss is None else loss + weighted
                loss.backward()
                optimizer.step()

                writer.add_scalar('training_loss', loss.item(), n_iter)
                for t in tasks:
                    writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

            # validation
            for m in model:
                model[m].eval()

            tot_loss, num_val_batches = _evaluate_tasks(
                model, val_loader, tasks, all_tasks, loss_fn, metric)

            avg_plcc = 0
            for t in tasks:
                writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t]/num_val_batches, n_iter)
                metric_results = metric[t].get_result()
                avg_plcc += metric_results['plcc']
                metric_str = 'task_{} : '.format(t)
                for metric_key in metric_results:
                    writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
                    metric_str += '{} = {}  '.format(metric_key, metric_results[metric_key])
                metric[t].reset()
                metric_str += 'loss = {}'.format(tot_loss[t]/num_val_batches)
                print(metric_str)
            # NOTE(review): the combined loss is divided by the dataset size
            # while the per-task losses above are divided by the number of
            # batches; kept as-is so logged curves stay comparable - confirm.
            print('all loss = {}'.format(tot_loss['all']/len(val_dst)))
            writer.add_scalar('validation_loss', tot_loss['all']/len(val_dst), n_iter)
            # Mean over the tasks actually trained (was a hard-coded 4).
            avg_plcc /= len(tasks)

            print(avg_plcc)
            print(best_val_plcc)
            if best_val_plcc < avg_plcc:
                best_val_plcc = avg_plcc
                # Checkpoint whenever the mean validation PLCC improves.
                print('Saving model...')
                state = {'epoch': epoch+1,
                        'model_rep': model['rep'].state_dict(),
                        'optimizer_state' : optimizer.state_dict()}
                for t in tasks:
                    key_name = 'model_{}'.format(t)
                    state[key_name] = model[t].state_dict()

                torch.save(state, "saved_models/{}_{}_{}_model.pkl".format(exp_identifier, epoch+1, fold))
                best_epoch = epoch + 1
                epochs_without_gain = 0
            else:
                # Plain ``else`` so that a NaN PLCC also counts as "no gain".
                epochs_without_gain += 1
                if epochs_without_gain == patience:
                    print('Val EMD loss has not decreased in %d epochs. Training terminated.' % patience)
                    break

            end = timer()
            print('Epoch ended in {}s'.format(end - start))

        print('Training completed.')
        # Flush pending events and release the event file.
        writer.close()
        return exp_identifier, best_val_plcc, best_epoch

    # test
    if params['test']:
        checkpoint_stem = '{}_{}_{}_model'.format(params['exp_identifier'], params['best_epoch'], params['best_fold'])
        state = torch.load(os.path.join('./saved_models', checkpoint_stem + '.pkl'))
        model['rep'].load_state_dict(state['model_rep'])
        for t in tasks:
            key_name = 'model_{}'.format(t)
            model[t].load_state_dict(state[key_name])
        print('Successfully loaded {}'.format(checkpoint_stem))

        for m in model:
            model[m].eval()

        # Only the metrics are reported; the accumulated losses are unused.
        _evaluate_tasks(model, test_loader, tasks, all_tasks, loss_fn, metric)

        print('test:')
        for t in tasks:
            test_metric_results = metric[t].get_result()
            test_metric_str = 'task_{} : '.format(t)
            for metric_key in test_metric_results:
                test_metric_str += '{} = {}  '.format(metric_key, test_metric_results[metric_key])
            metric[t].reset()
            print(test_metric_str)
Beispiel #11
0
    def _backward_step(self, input_train, target_train, input_valid,
                       target_valid, eta, network_optimizer):
        """Back-propagate a weighted sum of the enabled architecture objectives.

        The 'acc' objective is always evaluated; 'adv', 'nop', 'ood' and 'flp'
        are added depending on ``self.args`` (and, for 'adv'/'nop', on
        ``self.epoch``).  For each objective the gradient with respect to
        ``arch_parameters()`` is collected and rescaled with
        ``gradient_normalizers``.  When ``self.args.MGDA`` is set and more than
        one objective is active, ``MinNormSolver.find_min_norm_element``
        provides the objective weights; otherwise every weight is 1.  The
        weighted loss is then back-propagated.

        Args:
            input_train, target_train: batch used only on the unrolled path.
            input_valid, target_valid: batch the objectives are evaluated on.
            eta: forwarded to ``_compute_unrolled_model`` and used to scale
                the correction from ``_hessian_vector_product``.
            network_optimizer: forwarded to ``_compute_unrolled_model``.

        Returns:
            A ``logs`` namedtuple ``(sol, loss_data)``: the objective weights
            and the recorded loss value of each objective.
        """
        args = self.args
        task_grads = {}
        loss_data = {}
        objectives = {}  # objective name -> loss tensor, reused for the final sum

        self.optimizer.zero_grad()
        unrolled_model = (self._compute_unrolled_model(
            input_train, target_train, eta, network_optimizer)
                          if args.unrolled else self.model)

        def arch_gradient(objective):
            # retain_graph: the same loss tensors are back-propagated again
            # once the objective weights are known.
            return list(
                torch.autograd.grad(objective,
                                    unrolled_model.arch_parameters(),
                                    retain_graph=True))

        if args.adv_outer:
            # The adversarial objective needs the gradient w.r.t. the input.
            input_valid = Variable(input_valid.data, requires_grad=True).cuda()

        # ---- 'acc': loss on the validation batch --------------------------
        acc_loss = unrolled_model._loss(input_valid, target_valid)
        objectives['acc'] = acc_loss
        loss_data['acc'] = acc_loss.data[0] / 2  # recorded at half scale
        task_grads['acc'] = arch_gradient(acc_loss)

        # ---- 'adv': loss on a perturbed validation batch ------------------
        if args.adv_outer and (self.epoch >= args.adv_later):
            # Random start in [-epsilon, epsilon], then one signed-gradient
            # step of size 1.25 * epsilon, clamped back into range.
            perturbation = ((torch.rand(input_valid.size()) - 0.5) *
                            2).cuda() * self.epsilon
            input_grad = torch.autograd.grad(acc_loss,
                                             input_valid,
                                             retain_graph=True,
                                             create_graph=False)[0]
            input_grad = input_grad.detach().data
            step_size = self.epsilon * 1.25
            perturbation = clamp(
                perturbation + step_size * torch.sign(input_grad),
                -self.epsilon, self.epsilon)
            perturbation = clamp(perturbation,
                                 self.lower_limit - input_valid.data,
                                 self.upper_limit - input_valid.data)
            adv_input = Variable(input_valid.data + perturbation,
                                 requires_grad=False).cuda()
            self.optimizer.zero_grad()
            adv_loss = unrolled_model._loss(adv_input, target_valid)
            objectives['adv'] = adv_loss
            task_grads['adv'] = arch_gradient(adv_loss)
            loss_data['adv'] = adv_loss.data[0] / 2  # recorded at half scale

        # ---- 'nop': value returned by self.param_number -------------------
        if args.nop_outer and (self.epoch >= args.nop_later):
            self.optimizer.zero_grad()
            nop_loss = self.param_number(unrolled_model)
            objectives['nop'] = nop_loss
            loss_data['nop'] = nop_loss.data[0]
            task_grads['nop'] = arch_gradient(nop_loss)

        # ---- 'ood': KL term against a uniform target on self.ood_input ----
        if args.ood_outer:
            self.optimizer.zero_grad()
            ood_logits = unrolled_model.forward(self.ood_input)
            log_probs = F.log_softmax(ood_logits)
            uniform = torch.ones_like(ood_logits) / ood_logits.size()[-1]
            ood_loss = F.kl_div(input=log_probs, target=uniform)
            ood_loss = ood_loss * 50  # fixed rescaling of this objective
            objectives['ood'] = ood_loss
            loss_data['ood'] = ood_loss.data[0]
            task_grads['ood'] = arch_gradient(ood_loss)

        # ---- 'flp': value returned by self.cal_flops ----------------------
        if args.flp_outer:
            self.optimizer.zero_grad()
            flp_loss = self.cal_flops(unrolled_model)
            objectives['flp'] = flp_loss
            loss_data['flp'] = flp_loss.data[0]
            task_grads['flp'] = arch_gradient(flp_loss)

        # Rescale every gradient by its normaliser (epsilon avoids a zero
        # divisor); the lists are updated in place.
        gn = gradient_normalizers(task_grads, loss_data,
                                  normalization_type=args.grad_norm)
        for name, grad_list in task_grads.items():
            grad_list[:] = [g / (gn[name] + 1e-7) for g in grad_list]

        # ---- objective weights --------------------------------------------
        if args.MGDA and (len(task_grads) > 1):
            sol, _ = MinNormSolver.find_min_norm_element(
                [task_grads[name] for name in task_grads])
            sol = [x + 1e-7 for x in sol]
        else:
            sol = [1] * len(task_grads)

        # Weighted sum in the same key order that was handed to the solver.
        loss = 0
        for position, name in enumerate(task_grads):
            loss += float(sol[position]) * objectives[name]
        self.optimizer.zero_grad()
        loss.backward()

        if args.unrolled:
            arch_grads = [v.grad for v in unrolled_model.arch_parameters()]
            weight_grads = [v.grad.data for v in unrolled_model.parameters()]
            implicit_grads = self._hessian_vector_product(
                weight_grads, input_train, target_train)

            # Subtract eta * implicit gradient from each architecture gradient.
            for arch_grad, implicit in zip(arch_grads, implicit_grads):
                arch_grad.data.sub_(eta, implicit.data)

            # Copy the corrected gradients onto the real model's parameters.
            for param, arch_grad in zip(self.model.arch_parameters(),
                                        arch_grads):
                if param.grad is None:
                    param.grad = Variable(arch_grad.data)
                else:
                    param.grad.data.copy_(arch_grad.data)

        log_type = namedtuple("logs", ['sol', 'loss_data'])
        logs = log_type(sol=sol, loss_data=loss_data)
        print(logs)
        return logs
Beispiel #12
0
def _flat_param_grads(net):
    """Return detached, flattened copies of every populated ``.grad`` in ``net``."""
    return [
        param.grad.detach().clone().reshape(-1)
        for param in net.parameters()
        if param.grad is not None
    ]


def train(train_loader, nets, optimizer, criterions, epoch):
    """Run one epoch, weighting five losses with the min-norm solver.

    For every batch the gradient of each loss with respect to the parameters
    of ``nets['snet']`` is collected, L2-normalised via
    ``gradient_normalizers`` and passed to
    ``MinNormSolver.find_min_norm_element``; the returned coefficients weight
    the losses in the final backward pass before ``optimizer.step()``.

    Reads the module-level ``args`` (``cuda``, ``T``, ``print_freq``).

    Args:
        train_loader: iterable of ``(img, target)`` batches.
        nets: dict with the networks under 'snet' (put in train mode here) and
            'tnet' (provides the distribution ``q`` the other losses compare
            against).
        optimizer: stepped once per batch.
        criterions: accepted for interface compatibility; not used.
        epoch: only used in the progress message.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    cls_losses = AverageMeter()
    half_losses = AverageMeter()
    st_losses = AverageMeter()
    collision_losses = AverageMeter()
    min_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    snet = nets['snet']
    tnet = nets['tnet']

    snet.train()

    end = time.time()
    for idx, (img, target) in enumerate(train_loader, start=1):
        data_time.update(time.time() - end)

        if args.cuda:
            img = img.cuda()
            target = target.cuda()

        optimizer.zero_grad()

        # NOTE(review): this first snet pass only supplies the logit shape for
        # the one-hot target; it is kept so the number of train-mode forward
        # passes per batch is unchanged.
        _, _, _, _, output_s = snet(img)
        # NOTE(review): tnet runs with autograd enabled, as before - confirm
        # whether it should be wrapped in torch.no_grad().
        _, _, _, _, output_t = tnet(img)
        if isinstance(output_s, list):
            output_s = output_s[0]

        # One-hot targets on the same device as ``target`` (previously forced
        # onto CUDA even when args.cuda was off).
        target_onehot = torch.zeros(output_s.shape[0],
                                    output_s.shape[1],
                                    dtype=torch.float32,
                                    device=target.device)
        target_onehot.scatter_(1, target.reshape(-1, 1), 1)
        q = F.softmax(output_t / args.T, dim=1)

        loss_data = {}
        grads = {}

        # ---- loss 0: 'shannon' between softmax(output_s) and the one-hot ---
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        loss_ce = renyi_distill('shannon')(F.softmax(output_s, dim=1),
                                           target_onehot)
        loss_data[0] = loss_ce.item()
        loss_ce.backward()
        grads[0] = _flat_param_grads(snet)

        # ---- losses 1-4: renyi_distill variants between p and q ------------
        # Each one gets a fresh forward pass; retain_graph keeps the graph of
        # q (and of the last p) alive for the final weighted backward.
        kd_losses = []
        for task_idx, kind in enumerate(
                ('half-fixed', 'shannon', 'collision', 'min'), start=1):
            optimizer.zero_grad()
            _, _, _, _, output_s = snet(img)
            p = F.softmax(output_s / args.T, dim=1)
            kd_loss = renyi_distill(kind)(p, q) * (args.T**2)
            loss_data[task_idx] = kd_loss.item()
            kd_loss.backward(retain_graph=True)
            grads[task_idx] = _flat_param_grads(snet)
            kd_losses.append(kd_loss)
        loss_half, loss_shannon, loss_collision, loss_min = kd_losses

        num_tasks = len(grads)
        gn = gradient_normalizers(grads, loss_data, 'l2')
        for t in range(num_tasks):
            grads[t] = [g / gn[t] for g in grads[t]]
        # The solver's second return value (previously bound to ``min``,
        # shadowing the builtin) is not needed.
        sol, _ = MinNormSolver.find_min_norm_element(
            [grads[t] for t in range(num_tasks)])
        scale = [float(sol[t]) for t in range(num_tasks)]

        batch_size = img.size(0)
        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        cls_losses.update(loss_ce.item(), batch_size)
        half_losses.update(loss_half.item(), batch_size)
        st_losses.update(loss_shannon.item(), batch_size)
        collision_losses.update(loss_collision.item(), batch_size)
        min_losses.update(loss_min.item(), batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # ---- weighted update ----------------------------------------------
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        loss_ce = renyi_distill('shannon')(F.softmax(output_s, dim=1),
                                           target_onehot)
        loss = scale[0] * loss_ce
        # NOTE(review): ``p`` still comes from the last gradient-collection
        # forward pass, not from the pass just above, and the first variant is
        # 'half' here but 'half-fixed' when its gradient was collected. Both
        # are kept as in the original - confirm they are intended.
        for task_idx, kind in enumerate(
                ('half', 'shannon', 'collision', 'min'), start=1):
            kd_loss = renyi_distill(kind)(p, q) * (args.T**2)
            loss = loss + scale[task_idx] * kd_loss

        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % args.print_freq == 0:
            print(
                'Epoch[{0}]:[{1:03}/{2:03}] '
                'Time:{batch_time.val:.4f} '
                'Data:{data_time.val:.4f}  '
                'Cls:{cls_losses.val:.4f}({cls_losses.avg:.4f})  '
                'Half:{half_losses.val:.4f}({half_losses.avg:.4f})  '
                'ST:{st_losses.val:.4f}({st_losses.avg:.4f})  '
                'Collision:{collision_losses.val:.4f}({collision_losses.avg:.4f})  '
                'Min:{min_losses.val:.4f}({min_losses.avg:.4f})  '
                'prec@1:{top1.val:.2f}({top1.avg:.2f})  '
                'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                    epoch,
                    idx,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    cls_losses=cls_losses,
                    half_losses=half_losses,
                    st_losses=st_losses,
                    collision_losses=collision_losses,
                    min_losses=min_losses,
                    top1=top1,
                    top5=top5))