def train_multi_task(param_file):
    """Train a multi-task network configured by the JSON file ``param_file``.

    Dataset-level settings are read from ``configs.json`` in the working
    directory and experiment settings from ``param_file``.

    Data loaders, losses, metrics and the model come from the project's
    selector modules. The model is a dict of modules: a ``'rep'`` module
    shared by all tasks plus one module per task.

    For each batch, the per-task loss weights are set in one of two ways:

    * recomputed with the min-norm solver, when ``'mgda'`` is in
      ``params['algorithm']``;
    * read from ``params['scales']``, otherwise.

    The weighted sum of the task losses is then back-propagated.

    After every epoch the model is evaluated on the validation loader, and
    every 3 epochs a checkpoint is written to ``saved_models/``.

    NOTE(review): this code uses the pre-0.4 PyTorch API
    (``Variable(..., volatile=True)``, ``loss.data[0]``). On newer PyTorch
    releases these need ``torch.no_grad()`` and ``.item()`` respectively.

    NOTE(review): ``writer`` is not created here (its construction is
    commented out below). It must exist as a module-level name, or the
    first ``writer.add_scalar`` call raises ``NameError``.

    Args:
        param_file: path to a JSON file of experiment parameters. Keys read
            here include ``tasks``, ``optimizer``, ``lr``, ``dataset`` and
            ``algorithm``. Depending on the algorithm, ``use_approximation``,
            ``normalization_type`` or ``scales`` are also read.

    Raises:
        ValueError: if ``params['optimizer']`` mentions none of ``RMSprop``,
            ``Adam`` or ``SGD``.
    """
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    with open(param_file) as json_params:
        params = json.load(json_params)

    # The experiment id encodes every hyper-parameter except the task list.
    exp_identifier = []
    for (key, val) in params.items():
        if 'tasks' in key:
            continue
        exp_identifier += ['{}={}'.format(key, val)]

    exp_identifier = '|'.join(exp_identifier)
    params['exp_id'] = exp_identifier

    #writer = SummaryWriter(log_dir='runs/{}_{}'.format(params['exp_id'], datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    train_loader, train_dst, val_loader, val_dst = datasets.get_dataset(params, configs)
    loss_fn = losses.get_loss(params)
    metric = metrics.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    # Optimizer names are matched by substring. Fail fast on an unknown name
    # instead of hitting an unbound ``optimizer`` later in the loop.
    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(params['optimizer']))

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    if 'mgda' in params['algorithm']:
        approximate_norm_solution = params['use_approximation']
        if approximate_norm_solution:
            print('Using approximate min-norm solver')
        else:
            print('Using full solver')
    n_iter = 0
    for epoch in tqdm(range(NUM_EPOCHS)):
        start = timer()
        print('Epoch {} Started'.format(epoch))
        if (epoch+1) % 10 == 0:
            # Every 10 epochs, decay the LR by a factor of 0.85.
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.85
            print('Half the learning rate{}'.format(n_iter))

        for m in model:
            model[m].train()

        for batch in train_loader:
            n_iter += 1
            # First member is always images
            images = batch[0]
            images = Variable(images.cuda())

            labels = {}
            # Read all targets of all tasks; batch[i+1] holds the target of
            # all_tasks[i], and only the trained tasks are kept.
            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels[t] = batch[i+1]
                labels[t] = Variable(labels[t].cuda())

            # Scaling the loss functions based on the algorithm choice
            loss_data = {}
            grads = {}
            scale = {}
            mask = None
            masks = {}
            if 'mgda' in params['algorithm']:
                # Will use our MGDA_UB if approximate_norm_solution is True. Otherwise, will use MGDA

                if approximate_norm_solution:
                    optimizer.zero_grad()
                    # First compute representations (z)
                    images_volatile = Variable(images.data, volatile=True)
                    rep, mask = model['rep'](images_volatile, mask)
                    # As an approximate solution we only need gradients for input
                    if isinstance(rep, list):
                        # This is a hack to handle psp-net
                        rep = rep[0]
                        rep_variable = [Variable(rep.data.clone(), requires_grad=True)]
                        list_rep = True
                    else:
                        rep_variable = Variable(rep.data.clone(), requires_grad=True)
                        list_rep = False

                    # Compute gradients of each loss function wrt z
                    for t in tasks:
                        optimizer.zero_grad()
                        out_t, masks[t] = model[t](rep_variable, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        if list_rep:
                            grads[t].append(Variable(rep_variable[0].grad.data.clone(), requires_grad=False))
                            rep_variable[0].grad.data.zero_()
                        else:
                            grads[t].append(Variable(rep_variable.grad.data.clone(), requires_grad=False))
                            rep_variable.grad.data.zero_()
                else:
                    # This is MGDA
                    for t in tasks:
                        # Compute gradients of each loss function wrt parameters
                        optimizer.zero_grad()
                        rep, mask = model['rep'](images, mask)
                        out_t, masks[t] = model[t](rep, None)
                        loss = loss_fn[t](out_t, labels[t])
                        loss_data[t] = loss.data[0]
                        loss.backward()
                        grads[t] = []
                        for param in model['rep'].parameters():
                            if param.grad is not None:
                                grads[t].append(Variable(param.grad.data.clone(), requires_grad=False))

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    for gr_i in range(len(grads[t])):
                        grads[t][gr_i] = grads[t][gr_i] / gn[t]

                # Frank-Wolfe iteration to compute scales.
                sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])
            else:
                for t in tasks:
                    masks[t] = None
                    scale[t] = float(params['scales'][t])

            # Scaled back-propagation
            optimizer.zero_grad()
            rep, _ = model['rep'](images, mask)
            for i, t in enumerate(tasks):
                out_t, _ = model[t](rep, masks[t])
                loss_t = loss_fn[t](out_t, labels[t])
                loss_data[t] = loss_t.data[0]
                if i > 0:
                    loss = loss + scale[t]*loss_t
                else:
                    loss = scale[t]*loss_t
            loss.backward()
            optimizer.step()

            writer.add_scalar('training_loss', loss.data[0], n_iter)
            for t in tasks:
                writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

        for m in model:
            model[m].eval()

        tot_loss = {}
        tot_loss['all'] = 0.0
        for t in tasks:
            tot_loss[t] = 0.0

        num_val_batches = 0
        for batch_val in val_loader:
            val_images = Variable(batch_val[0].cuda(), volatile=True)
            labels_val = {}

            for i, t in enumerate(all_tasks):
                if t not in tasks:
                    continue
                labels_val[t] = batch_val[i+1]
                labels_val[t] = Variable(labels_val[t].cuda(), volatile=True)

            val_rep, _ = model['rep'](val_images, None)
            for t in tasks:
                out_t_val, _ = model[t](val_rep, None)
                loss_t = loss_fn[t](out_t_val, labels_val[t])
                tot_loss['all'] += loss_t.data[0]
                tot_loss[t] += loss_t.data[0]
                metric[t].update(out_t_val, labels_val[t])
            num_val_batches += 1

        # Average all validation losses per batch. The overall loss used to be
        # divided by the dataset size, inconsistent with the per-task values.
        # max(..., 1) guards against an empty validation loader.
        denom = max(num_val_batches, 1)
        for t in tasks:
            writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t]/denom, n_iter)
            metric_results = metric[t].get_result()
            for metric_key in metric_results:
                writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
            metric[t].reset()
        writer.add_scalar('validation_loss', tot_loss['all']/denom, n_iter)

        if epoch % 3 == 0:
            # Save after every 3 epoch
            state = {'epoch': epoch+1,
                    'model_rep': model['rep'].state_dict(),
                    'optimizer_state' : optimizer.state_dict()}
            for t in tasks:
                key_name = 'model_{}'.format(t)
                state[key_name] = model[t].state_dict()

            torch.save(state, "saved_models/{}_{}_model.pkl".format(params['exp_id'], epoch+1))

        end = timer()
        print('Epoch ended in {}s'.format(end - start))
예제 #2
0
    def step(self, batch):
        """Run one MGDA (or approximate MGDA-UB) update on ``batch``.

        Per-objective gradients are collected and normalised with
        ``gradient_normalizers``. They are then passed to
        ``MinNormSolver.find_min_norm_element``, whose solution weights the
        objectives in a final forward/backward pass.

        This method only back-propagates; it does not call an optimizer step.

        The two gradient modes:

        * ``self.approximate_norm_solution`` set: the features are computed
          once under ``torch.no_grad()`` via
          ``self.model.forward_feature_extraction``, and each objective's
          head (``self.model.forward_linear``) is differentiated separately.
        * otherwise: ``calc_gradients`` computes the per-objective gradients.

        ``batch`` is a dict. It is updated in place with the model outputs and
        then unpacked into each objective as keyword arguments.

        Returns:
            tuple: ``(weighted total loss as a Python float, 0)``.
        """
        if self.approximate_norm_solution:
            self.model.zero_grad()
            # First compute representations (z); no autograd graph is kept.
            with torch.no_grad():
                rep = self.model.forward_feature_extraction(batch)

            # NOTE(review): because ``rep`` is produced under no_grad, no
            # gradient reaches the feature extractor. The collected gradients
            # are those of the non-private parameters used by
            # ``forward_linear``. Confirm this is intended, since classic
            # MGDA-UB differentiates w.r.t. the representation itself.

            # Parameters whose names contain any of these substrings are
            # excluded from the MGDA gradients. This is loop-invariant, so it
            # is computed once.
            private_params = self.model.private_params() if hasattr(
                self.model, 'private_params') else []

            grads = []
            obj_values = []
            for i, objective in enumerate(self.objectives):
                self.model.zero_grad()

                logits = self.model.forward_linear(rep, i)
                batch.update(logits)

                output = objective(**batch)
                output.backward()
                obj_values.append(output.item())

                task_grads = {}
                for name, param in self.model.named_parameters():
                    not_private = all(p not in name for p in private_params)
                    if not_private and param.requires_grad and param.grad is not None:
                        task_grads[name] = param.grad.data.detach().clone()
                        param.grad = None
                grads.append(task_grads)
                self.model.zero_grad()
        else:
            # This is MGDA: full per-objective gradients.
            grads, obj_values = calc_gradients(batch, self.model,
                                               self.objectives)

        # Normalize all gradients, this is optional and not included in the paper.
        gn = gradient_normalizers(grads, obj_values, self.normalization_type)
        for t in range(len(self.objectives)):
            for name in grads[t]:
                grads[t][name] = grads[t][name] / gn[t]

        # Frank-Wolfe iteration to compute scales.
        grads = [list(d.values()) for d in grads]
        sol, _ = MinNormSolver.find_min_norm_element(grads)

        # Scaled back-propagation.
        self.model.zero_grad()
        logits = self.model(batch)
        batch.update(logits)
        loss_total = None
        for a, objective in zip(sol, self.objectives):
            weighted_loss = float(a) * objective(**batch)
            # Use an explicit None sentinel. Testing the tensor's truthiness
            # would restart the sum whenever the running total is exactly 0,
            # and it fails for non-scalar tensors.
            if loss_total is None:
                loss_total = weighted_loss
            else:
                loss_total = loss_total + weighted_loss

        loss_total.backward()
        return loss_total.item(), 0
예제 #3
0
                abs_err, rel_err = model.depth_error(out_t, label)
            else:
                label = train_normal
                mean, med, map1, map2, map3 = model.normal_error(out_t, label)
            ### end scoring ###
            task_loss = model.model_fit(out_t, label, t)
            task_losses.append(task_loss[0].item())
            loss_data[t] = task_loss.data
            task_loss.backward()
            grads[t] = []
            grads[t].append(
                Variable(rep_variable.grad.data.clone(), requires_grad=False))
            rep_variable.grad.data.zero_()

        # Normalize all gradients, this is optional and not included in the paper.
        gn = gradient_normalizers(grads, loss_data, 'none')
        for t in range(num_tasks):
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / gn[t]

        # Frank-Wolfe iteration to compute scales.
        sol, min_norm = MinNormSolver.find_min_norm_element(
            [grads[t] for t in range(num_tasks)])
        for t in range(num_tasks):
            scale[t] = float(sol[t])

        # Scaled back-propagation
        optimizer.zero_grad()
        feats = model.forward_shared(train_data)
        for t in range(num_tasks):
            out_t = model.forward_task(feats, t)
예제 #4
0
def train_multi_task(params, fold=0):
    """Train and/or test a multi-task model weighted with MGDA-UB.

    Dataset-level settings are read from ``configs.json``. The model is a dict
    of modules: a ``'rep'`` module shared by all tasks plus one module per
    task.

    When ``params['train']`` is truthy:

    * Each batch, the per-task loss weights are computed with the min-norm
      solver on gradients w.r.t. a detached copy of the shared representation.
    * After each epoch, the average validation PLCC over the trained tasks
      drives checkpointing and early stopping (10 epochs without improvement).
    * The function returns as soon as training ends, so the test phase below
      does not run in the same call.

    When only ``params['test']`` is truthy, the checkpoint named by
    ``params['exp_identifier']``, ``params['best_epoch']`` and
    ``params['best_fold']`` is loaded, and the test metrics are printed.

    Args:
        params: dict of experiment settings. Keys read here include
            ``train``, ``test``, ``tasks``, ``dataset``, ``optimizer``, ``lr``
            and ``normalization_type``. For testing, ``exp_identifier``,
            ``best_epoch`` and ``best_fold`` are also read.
        fold: fold index forwarded to ``dataset_selector.get_dataset`` for
            training and used in checkpoint file names.

    Returns:
        When training: ``(exp_identifier, best average validation PLCC,
        best epoch)``. Otherwise ``None``.

    Raises:
        ValueError: if ``params['optimizer']`` mentions none of ``RMSprop``,
            ``Adam`` or ``SGD``.
    """
    with open('configs.json') as config_params:
        configs = json.load(config_params)

    # The experiment id encodes the hyper-parameters only. Bookkeeping keys
    # are skipped by substring match.
    skipped_keys = ('tasks', 'dataset', 'normalization_type', 'grid_search',
                    'train', 'test')
    exp_identifier = []
    for (key, val) in params.items():
        if any(skipped in key for skipped in skipped_keys):
            continue
        exp_identifier += ['{}={}'.format(key, val)]
    exp_identifier = '|'.join(exp_identifier)

    if params['train']:
        train_loader, train_dst, val_loader, val_dst = dataset_selector.get_dataset(params, configs, fold)
        writer = SummaryWriter(log_dir='5fold_runs/{}_{}'.format(exp_identifier, datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")))

    if params['test']:
        test_loader, test_dst = dataset_selector.get_dataset(params, configs)

    loss_fn = loss_selector.get_loss(params)
    metric = metrics_selector.get_metrics(params)

    model = model_selector.get_model(params)
    model_params = []
    for m in model:
        model_params += model[m].parameters()

    # Optimizer names are matched by substring. Fail fast on an unknown name
    # instead of hitting an unbound ``optimizer`` later.
    if 'RMSprop' in params['optimizer']:
        optimizer = torch.optim.RMSprop(model_params, lr=params['lr'])
    elif 'Adam' in params['optimizer']:
        optimizer = torch.optim.Adam(model_params, lr=params['lr'])
    elif 'SGD' in params['optimizer']:
        optimizer = torch.optim.SGD(model_params, lr=params['lr'], momentum=0.9)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(params['optimizer']))

    tasks = params['tasks']
    all_tasks = configs[params['dataset']]['all_tasks']
    print('Starting training with parameters \n \t{} \n'.format(str(params)))

    n_iter = 0

    # Early-stopping state: epochs since the last improvement, the best average
    # validation PLCC so far, and the epoch that produced it.
    count = 0
    init_val_plcc = 0
    best_epoch = 0

    # train
    if params['train']:

        for epoch in tqdm(range(NUM_EPOCHS)):
            start = timer()
            print('Epoch {} Started'.format(epoch))
            if (epoch+1) % 30 == 0:
                # Every 30 epochs, halve the LR.
                for param_group in optimizer.param_groups:
                    param_group['lr'] *= 0.5
                print('Half the learning rate {}'.format(n_iter))

            for m in model:
                model[m].train()

            for batch in train_loader:
                n_iter += 1

                # First member is always images. batch[i+1] holds the target
                # of all_tasks[i]; only the trained tasks are kept.
                images = Variable(batch[0].cuda())
                labels = {}
                for i, t in enumerate(all_tasks):
                    if t not in tasks:
                        continue
                    labels[t] = Variable(batch[i+1].cuda())

                loss_data = {}
                grads = {}
                scale = {}
                mask = None
                masks = {}

                # MGDA-UB: compute the shared representation z without building
                # an autograd graph. Each task's gradient is then taken w.r.t. a
                # detached copy of z.
                optimizer.zero_grad()
                with torch.no_grad():
                    rep, mask = model['rep'](images, mask)
                rep_variable = Variable(rep.data.clone(), requires_grad=True)

                for t in tasks:
                    optimizer.zero_grad()
                    out_t, masks[t] = model[t](rep_variable, None)
                    loss = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss.item()
                    loss.backward()
                    # Store as a one-element list of tensors (per-task gradient list).
                    grads[t] = [Variable(rep_variable.grad.data.clone(), requires_grad=False)]
                    rep_variable.grad.data.zero_()

                # Normalize all gradients, this is optional and not included in the paper.
                gn = gradient_normalizers(grads, loss_data, params['normalization_type'])
                for t in tasks:
                    grads[t] = [gr / gn[t] for gr in grads[t]]

                # Frank-Wolfe iteration to compute scales.
                sol, _ = MinNormSolver.find_min_norm_element([grads[t] for t in tasks])
                for i, t in enumerate(tasks):
                    scale[t] = float(sol[i])

                # Scaled back-propagation
                optimizer.zero_grad()
                rep, _ = model['rep'](images, mask)
                for i, t in enumerate(tasks):
                    out_t, _ = model[t](rep, masks[t])
                    loss_t = loss_fn[t](out_t, labels[t])
                    loss_data[t] = loss_t.item()
                    if i > 0:
                        loss = loss + scale[t]*loss_t
                    else:
                        loss = scale[t]*loss_t
                loss.backward()
                optimizer.step()

                writer.add_scalar('training_loss', loss.item(), n_iter)
                for t in tasks:
                    writer.add_scalar('training_loss_{}'.format(t), loss_data[t], n_iter)

            # validation
            for m in model:
                model[m].eval()

            tot_loss = {'all': 0.0}
            for t in tasks:
                tot_loss[t] = 0.0

            num_val_batches = 0
            # Inference only: no autograd graph for the forward passes.
            with torch.no_grad():
                for batch_val in val_loader:
                    val_images = batch_val[0].cuda()
                    labels_val = {}
                    for i, t in enumerate(all_tasks):
                        if t not in tasks:
                            continue
                        labels_val[t] = batch_val[i+1].cuda()

                    val_rep, _ = model['rep'](val_images, None)
                    for t in tasks:
                        out_t_val, _ = model[t](val_rep, None)
                        loss_t = loss_fn[t](out_t_val, labels_val[t])
                        tot_loss['all'] += loss_t.item()
                        tot_loss[t] += loss_t.item()
                        metric[t].update(out_t_val, labels_val[t])
                    num_val_batches += 1

            # Average all validation losses per batch. The overall loss used to
            # be divided by the dataset size, inconsistent with the per-task
            # values. max(..., 1) guards against an empty validation loader.
            denom = max(num_val_batches, 1)
            avg_plcc = 0
            for t in tasks:
                writer.add_scalar('validation_loss_{}'.format(t), tot_loss[t]/denom, n_iter)
                metric_results = metric[t].get_result()
                avg_plcc += metric_results['plcc']
                metric_str = 'task_{} : '.format(t)
                for metric_key in metric_results:
                    writer.add_scalar('metric_{}_{}'.format(metric_key, t), metric_results[metric_key], n_iter)
                    metric_str += '{} = {}  '.format(metric_key, metric_results[metric_key])
                metric[t].reset()
                metric_str += 'loss = {}'.format(tot_loss[t]/denom)
                print(metric_str)
            print('all loss = {}'.format(tot_loss['all']/denom))
            writer.add_scalar('validation_loss', tot_loss['all']/denom, n_iter)
            # Mean PLCC over the trained tasks (previously hard-coded to 4 tasks).
            avg_plcc /= len(tasks)

            print(avg_plcc)
            print(init_val_plcc)
            if init_val_plcc < avg_plcc:
                init_val_plcc = avg_plcc
                # Checkpoint whenever the average validation PLCC improves.
                print('Saving model...')
                state = {'epoch': epoch+1,
                        'model_rep': model['rep'].state_dict(),
                        'optimizer_state' : optimizer.state_dict()}
                for t in tasks:
                    key_name = 'model_{}'.format(t)
                    state[key_name] = model[t].state_dict()

                torch.save(state, "saved_models/{}_{}_{}_model.pkl".format(exp_identifier, epoch+1, fold))
                best_epoch = epoch + 1
                # reset count
                count = 0
            else:
                # No improvement (a NaN PLCC also counts). Stop after 10 such
                # epochs in a row.
                count += 1
                if count == 10:
                    print('Val EMD loss has not decreased in %d epochs. Training terminated.' % 10)
                    break

            end = timer()
            print('Epoch ended in {}s'.format(end - start))

        print('Training completed.')
        # NOTE: returning here means the test phase below never runs when both
        # params['train'] and params['test'] are set.
        return exp_identifier, init_val_plcc, best_epoch

    # test
    if params['test']:
        checkpoint_name = "{}_{}_{}_model.pkl".format(params['exp_identifier'], params['best_epoch'], params['best_fold'])
        state = torch.load(os.path.join('./saved_models', checkpoint_name))
        model['rep'].load_state_dict(state['model_rep'])
        for t in tasks:
            key_name = 'model_{}'.format(t)
            model[t].load_state_dict(state[key_name])
        print('Successfully loaded {}_{}_{}_model'.format(params['exp_identifier'], params['best_epoch'], params['best_fold']))

        for m in model:
            model[m].eval()

        # Inference only: no autograd graph for the forward passes.
        with torch.no_grad():
            for batch_test in test_loader:
                test_images = batch_test[0].cuda()
                labels_test = {}
                for i, t in enumerate(all_tasks):
                    if t not in tasks:
                        continue
                    labels_test[t] = batch_test[i+1].cuda()

                test_rep, _ = model['rep'](test_images, None)
                for t in tasks:
                    out_t_test, _ = model[t](test_rep, None)
                    metric[t].update(out_t_test, labels_test[t])

        print('test:')
        for t in tasks:
            test_metric_results = metric[t].get_result()
            test_metric_str = 'task_{} : '.format(t)
            for metric_key in test_metric_results:
                test_metric_str += '{} = {}  '.format(metric_key, test_metric_results[metric_key])
            metric[t].reset()
            print(test_metric_str)
예제 #5
0
    def _backward_step(self, input_train, target_train, input_valid,
                       target_valid, eta, network_optimizer):
        """Back-propagate a weighted combination of outer objectives.

        Up to five objectives are built on the validation batch, each enabled
        by flags on ``self.args``:

        * 'acc': always on.
        * 'adv': ``adv_outer`` and epoch >= ``adv_later``.
        * 'nop': ``nop_outer`` and epoch >= ``nop_later``.
        * 'ood': ``ood_outer``.
        * 'flp': ``flp_outer``.

        For each objective, the gradient w.r.t.
        ``unrolled_model.arch_parameters()`` is taken with
        ``torch.autograd.grad``. These gradients are divided by
        ``gradient_normalizers`` factors (plus 1e-7).

        When ``self.args.MGDA`` is set and more than one objective is active,
        the weights come from ``MinNormSolver.find_min_norm_element`` (each
        shifted by 1e-7). Otherwise every weight is 1. The weighted sum of the
        objective losses is then back-propagated with ``loss.backward()``.

        When ``self.args.unrolled`` is set, the architecture gradients of the
        unrolled model are reduced by ``eta`` times
        ``self._hessian_vector_product(...)``. The result is copied into the
        ``.grad`` of ``self.model.arch_parameters()``.

        NOTE(review): uses the pre-0.4 PyTorch API (``.data[0]``,
        ``Variable``, ``Tensor.sub_(scalar, tensor)``). On newer releases,
        ``.data[0]`` on a 0-dim tensor raises and should become ``.item()``.

        Args:
            input_train: training inputs, used only when
                ``self.args.unrolled`` is set.
            target_train: training targets, used only when
                ``self.args.unrolled`` is set.
            input_valid: validation inputs for the outer losses.
            target_valid: validation targets for the outer losses.
            eta: step size forwarded to ``_compute_unrolled_model`` and used
                to scale the implicit gradients.
            network_optimizer: forwarded to ``_compute_unrolled_model``.

        Returns:
            A namedtuple ``logs(sol, loss_data)``. ``sol`` holds the objective
            weights, in ``grads`` iteration order. ``loss_data`` maps each
            objective name to its scalar loss value.
        """
        grads = {}
        loss_data = {}
        self.optimizer.zero_grad()
        if self.args.unrolled:
            unrolled_model = self._compute_unrolled_model(
                input_train, target_train, eta, network_optimizer)
        else:
            unrolled_model = self.model
        # The adversarial objective needs d(loss)/d(input), so the validation
        # input is re-wrapped as a tensor that requires grad.
        if self.args.adv_outer:
            input_valid = Variable(input_valid.data, requires_grad=True).cuda()

        # ---- acc loss ----
        # retain_graph=True keeps the graph alive: it is differentiated again
        # below (adversarial step, final loss.backward()).
        unrolled_loss = unrolled_model._loss(input_valid, target_valid)
        loss_data['acc'] = unrolled_loss.data[0] / 2  # lossNorm
        grads['acc'] = list(
            torch.autograd.grad(unrolled_loss,
                                unrolled_model.arch_parameters(),
                                retain_graph=True))
        # ---- acc loss end ----

        # ---- adv loss ----
        # One-step perturbation of the validation input:
        # 1. random start in [-epsilon, epsilon];
        # 2. a step of 1.25*epsilon along sign(d loss / d input);
        # 3. clamp to [-epsilon, epsilon], then so that input + delta stays
        #    within [lower_limit, upper_limit].
        # Assumes the external `clamp` helper clips elementwise (not shown here).
        if self.args.adv_outer and (self.epoch >= self.args.adv_later):
            step_size = self.epsilon * 1.25
            delta = ((torch.rand(input_valid.size()) - 0.5) *
                     2).cuda() * self.epsilon
            adv_grad = torch.autograd.grad(unrolled_loss,
                                           input_valid,
                                           retain_graph=True,
                                           create_graph=False)[0]
            adv_grad = adv_grad.detach().data
            delta = clamp(delta + step_size * torch.sign(adv_grad),
                          -self.epsilon, self.epsilon)
            delta = clamp(delta, self.lower_limit - input_valid.data,
                          self.upper_limit - input_valid.data)
            adv_input = Variable(input_valid.data + delta,
                                 requires_grad=False).cuda()
            self.optimizer.zero_grad()
            unrolled_loss_adv = unrolled_model._loss(adv_input, target_valid)
            grads['adv'] = list(
                torch.autograd.grad(unrolled_loss_adv,
                                    unrolled_model.arch_parameters(),
                                    retain_graph=True))
            loss_data['adv'] = unrolled_loss_adv.data[0] / 2  # lossNorm
        # ---- adv loss end ----

        # ---- param loss ----
        # self.param_number is not shown here. It presumably returns a
        # differentiable model-size penalty; confirm against its definition.
        if self.args.nop_outer and (self.epoch >= self.args.nop_later):
            self.optimizer.zero_grad()
            param_loss = self.param_number(unrolled_model)
            loss_data['nop'] = param_loss.data[0]
            grads['nop'] = list(
                torch.autograd.grad(param_loss,
                                    unrolled_model.arch_parameters(),
                                    retain_graph=True))
        # ---- param loss end ----

        # ---- ood loss ----
        # KL divergence between a uniform target distribution and the model's
        # softmax on self.ood_input, scaled by 50.
        # NOTE(review): F.log_softmax is called without an explicit `dim`,
        # which newer PyTorch versions warn about.
        if self.args.ood_outer:
            self.optimizer.zero_grad()
            ood_logits = unrolled_model.forward(self.ood_input)
            ood_loss = F.kl_div(input=F.log_softmax(ood_logits),
                                target=torch.ones_like(ood_logits) /
                                ood_logits.size()[-1])
            ood_loss = ood_loss * 50  # lossNorm, 10
            loss_data['ood'] = ood_loss.data[0]
            grads['ood'] = list(
                torch.autograd.grad(ood_loss,
                                    unrolled_model.arch_parameters(),
                                    retain_graph=True))
        # ---- ood loss end ----

        # ---- flops loss ----
        if self.args.flp_outer:
            self.optimizer.zero_grad()
            flp_loss = self.cal_flops(unrolled_model)
            loss_data['flp'] = flp_loss.data[0]
            grads['flp'] = list(
                torch.autograd.grad(flp_loss,
                                    unrolled_model.arch_parameters(),
                                    retain_graph=True))
        # ---- flops loss end ----

        gn = gradient_normalizers(
            grads, loss_data,
            normalization_type=self.args.grad_norm)  # loss+, loss, l2

        # 1e-7 avoids division by zero when a normaliser is 0.
        for t in grads:
            for gr_i in range(len(grads[t])):
                grads[t][gr_i] = grads[t][gr_i] / (gn[t] + 1e-7)

        # ---- MGDA -----
        # sol[k] weights the k-th key of `grads`. The list given to the solver
        # and the loop below both iterate `grads` in the same order.
        if self.args.MGDA and (len(grads) > 1):
            sol, _ = MinNormSolver.find_min_norm_element(
                [grads[t] for t in grads])
            sol = [x + 1e-7 for x in sol]
        else:
            sol = [1] * len(grads)
        # print(sol) # acc, adv, nop

        # Each key in `grads` was added together with its loss variable above,
        # so the matching variable is always bound here.
        loss = 0
        for kk, t in enumerate(grads):
            if t == 'acc':
                loss += float(sol[kk]) * unrolled_loss
            elif t == 'adv':
                loss += float(sol[kk]) * unrolled_loss_adv
            elif t == 'nop':
                loss += float(sol[kk]) * param_loss
            elif t == 'ood':
                loss += float(sol[kk]) * ood_loss
            elif t == 'flp':
                loss += float(sol[kk]) * flp_loss
        self.optimizer.zero_grad()
        loss.backward()
        # ---- MGDA end -----

        if self.args.unrolled:
            dalpha = [v.grad for v in unrolled_model.arch_parameters()]
            vector = [v.grad.data for v in unrolled_model.parameters()]
            implicit_grads = self._hessian_vector_product(
                vector, input_train, target_train)

            # Legacy signature: g.sub_(eta, ig) means g -= eta * ig.
            for g, ig in zip(dalpha, implicit_grads):
                g.data.sub_(eta, ig.data)

            # Copy the corrected architecture gradients onto the real model.
            for v, g in zip(self.model.arch_parameters(), dalpha):
                if v.grad is None:
                    v.grad = Variable(g.data)
                else:
                    v.grad.data.copy_(g.data)

        # aa = [[gr.pow(2).sum().data[0] for gr in grads[t]] for t in grads]
        logs = namedtuple("logs", ['sol', 'loss_data'])(sol, loss_data)
        # logs.sol = sol
        # logs.param_loss = param_loss
        print(logs)
        return logs
예제 #6
0
파일: train_st.py 프로젝트: CML00/MoDiv
# Distillation-loss kinds passed to ``renyi_distill`` for tasks 1..4 (task 0 is
# the classification loss). Index k-1 of each tuple is the kind for task k.
# NOTE(review): task 1 uses 'half-fixed' when estimating per-task gradients but
# 'half' in the loss that is actually optimized; confirm this is intentional.
_MGDA_GRAD_KD_KINDS = ('half-fixed', 'shannon', 'collision', 'min')
_MGDA_STEP_KD_KINDS = ('half', 'shannon', 'collision', 'min')


def _mgda_task_loss(task, output_s, target_onehot, q, kd_kinds):
    """Return the loss for MGDA task index *task* from student logits.

    Task 0 is ``renyi_distill('shannon')`` between the student softmax and the
    one-hot targets. Task ``k >= 1`` is ``renyi_distill(kd_kinds[k - 1])``
    between the temperature-softened student softmax and the teacher
    probabilities *q*, scaled by ``args.T ** 2``.
    """
    if task == 0:
        return renyi_distill('shannon')(F.softmax(output_s, dim=1),
                                        target_onehot)
    p = F.softmax(output_s / args.T, dim=1)
    return renyi_distill(kd_kinds[task - 1])(p, q) * (args.T ** 2)


def _flat_grads(net):
    """Return detached, flattened copies of the ``.grad`` of *net*'s parameters.

    Parameters whose ``.grad`` is ``None`` are skipped.
    """
    return [param.grad.detach().clone().reshape(-1)
            for param in net.parameters() if param.grad is not None]


def train(train_loader, nets, optimizer, criterions, epoch):
    """Train the student network for one epoch with MGDA-weighted losses.

    For every batch, five losses are considered: a classification loss against
    one-hot targets and four temperature-scaled ``renyi_distill`` losses
    between the student and teacher softmax outputs. Each loss is
    back-propagated on its own (one fresh student forward per loss) to collect
    its gradient; the gradients are L2-normalised and
    ``MinNormSolver.find_min_norm_element`` yields per-loss weights. One
    optimizer step is then taken on the weighted sum of the losses.

    Args:
        train_loader: iterable of ``(img, target)`` batches.
        nets: dict holding the student under ``'snet'`` and the teacher under
            ``'tnet'``.
        optimizer: optimizer used for ``zero_grad``/``step``; presumably over
            the student's parameters — confirm at the call site.
        criterions: accepted for interface compatibility; not used here.
        epoch: epoch index, used only in the progress printout.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    cls_losses = AverageMeter()
    half_losses = AverageMeter()
    st_losses = AverageMeter()
    collision_losses = AverageMeter()
    min_losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # Loss meters indexed like the MGDA tasks (0 = classification).
    task_meters = [cls_losses, half_losses, st_losses,
                   collision_losses, min_losses]
    num_tasks = len(task_meters)

    snet = nets['snet']
    tnet = nets['tnet']

    snet.train()

    end = time.time()
    for idx, (img, target) in enumerate(train_loader, start=1):
        data_time.update(time.time() - end)

        if args.cuda:
            img = img.cuda()
            target = target.cuda()

        # The teacher only supplies soft targets. Without a graph through
        # tnet, the per-task backward passes below need no retain_graph and
        # do not accumulate gradients into the teacher's parameters.
        with torch.no_grad():
            _, _, _, _, output_t = tnet(img)
        q = F.softmax(output_t / args.T, dim=1)

        # One-hot targets on the logits' device and dtype (no hard-coded
        # .cuda(), so runs with args.cuda disabled work as well).
        target_onehot = torch.zeros_like(output_t).scatter_(
            1, target.reshape(-1, 1), 1)

        # ---- Per-task gradients: one fresh student forward per loss ----
        loss_data = {}
        grads = {}
        for t in range(num_tasks):
            optimizer.zero_grad()
            _, _, _, _, output_s = snet(img)
            task_loss = _mgda_task_loss(t, output_s, target_onehot, q,
                                        _MGDA_GRAD_KD_KINDS)
            loss_data[t] = task_loss.item()
            task_loss.backward()
            grads[t] = _flat_grads(snet)

        # ---- MGDA weights from L2-normalised gradients ----
        gn = gradient_normalizers(grads, loss_data, 'l2')
        for t in range(num_tasks):
            # Epsilon guards against a zero-norm gradient producing NaNs.
            grads[t] = [g / (gn[t] + 1e-7) for g in grads[t]]
        sol, _ = MinNormSolver.find_min_norm_element(
            [grads[t] for t in range(num_tasks)])
        scale = [float(w) for w in sol]

        # Statistics use the last gradient-pass forward and per-task losses.
        prec1, prec5 = accuracy(output_s, target, topk=(1, 5))
        batch_size = img.size(0)
        for t, meter in enumerate(task_meters):
            meter.update(loss_data[t], batch_size)
        top1.update(prec1.item(), batch_size)
        top5.update(prec5.item(), batch_size)

        # ---- Single update on the MGDA-weighted sum of the losses ----
        # Every term is computed from this fresh forward's output.
        optimizer.zero_grad()
        _, _, _, _, output_s = snet(img)
        loss = sum(scale[t] * _mgda_task_loss(t, output_s, target_onehot, q,
                                              _MGDA_STEP_KD_KINDS)
                   for t in range(num_tasks))
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if idx % args.print_freq == 0:
            print(
                'Epoch[{0}]:[{1:03}/{2:03}] '
                'Time:{batch_time.val:.4f} '
                'Data:{data_time.val:.4f}  '
                'Cls:{cls_losses.val:.4f}({cls_losses.avg:.4f})  '
                'Half:{half_losses.val:.4f}({half_losses.avg:.4f})  '
                'ST:{st_losses.val:.4f}({st_losses.avg:.4f})  '
                'Collision:{collision_losses.val:.4f}({collision_losses.avg:.4f})  '
                'Min:{min_losses.val:.4f}({min_losses.avg:.4f})  '
                'prec@1:{top1.val:.2f}({top1.avg:.2f})  '
                'prec@5:{top5.val:.2f}({top5.avg:.2f})'.format(
                    epoch,
                    idx,
                    len(train_loader),
                    batch_time=batch_time,
                    data_time=data_time,
                    cls_losses=cls_losses,
                    half_losses=half_losses,
                    st_losses=st_losses,
                    collision_losses=collision_losses,
                    min_losses=min_losses,
                    top1=top1,
                    top5=top5))