Example 1
class Model():
    def __init__(self, configuration, pre_embed=None):
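        # self.configuration is copied before the mutations below (pre_embed injection,
        # decoder hidden_size), so the config dumped by save_values() does not carry them.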
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed
        self.encoder = Encoder.from_params(
            Params(configuration['model']['encoder'])).to(device)

        configuration['model']['decoder'][
            'hidden_size'] = self.encoder.output_size
        self.decoder = AttnDecoder.from_params(
            Params(configuration['model']['decoder'])).to(device)

        self.encoder_params = list(self.encoder.parameters())
        self.attn_params = [
            v for k, v in self.decoder.named_parameters() if 'attention' in k
        ]
        self.decoder_params = [
            v for k, v in self.decoder.named_parameters()
            if 'attention' not in k
        ]

        self.bsize = configuration['training']['bsize']

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.encoder_optim = torch.optim.Adam(self.encoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.attn_optim = torch.optim.Adam(self.attn_params,
                                           lr=0.001,
                                           weight_decay=0,
                                           amsgrad=True)
        self.decoder_optim = torch.optim.Adam(self.decoder_params,
                                              lr=0.001,
                                              weight_decay=weight_decay,
                                              amsgrad=True)
        self.adversarymulti = AdversaryMulti(decoder=self.decoder)

        self.all_params = self.encoder_params + self.attn_params + self.decoder_params
        self.all_optim = torch.optim.Adam(self.all_params,
                                          lr=0.001,
                                          weight_decay=weight_decay,
                                          amsgrad=True)
        # self.all_optim = adagrad.Adagrad(self.all_params, weight_decay=weight_decay)

        pos_weight = configuration['training'].get('pos_weight', [1.0] *
                                                   self.decoder.output_size)
        self.pos_weight = torch.Tensor(pos_weight).to(device)
        self.criterion = nn.BCEWithLogitsLoss(reduction='none').to(device)
        self.swa_settings = configuration['training']['swa']

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.temperature = configuration['training']['temperature']
        self.train_losses = []

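        # Judging from iter_for_swa_update() and check_and_update_swa() below,
        # swa_settings appears to hold [enable_flag, start_iter, update_interval, norm_mode].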
        if self.swa_settings[0]:
            # self.attn_optim = SWA(self.attn_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.decoder_optim = SWA(self.decoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            # self.encoder_optim = SWA(self.encoder_optim, swa_start=3, swa_freq=1, swa_lr=0.05)
            self.swa_all_optim = SWA(self.all_optim)
            self.running_norms = []

    @classmethod
    def init_from_config(cls, dirname, **kwargs):
        config = json.load(open(dirname + '/config.json', 'r'))
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        # for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[1, 2, 5, 6, 9]]:
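        # The hard-coded indices below seem to pick out specific weight tensors by
        # position; this assumes a fixed parameter ordering and breaks silently if
        # the architecture changes.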
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[[6,
                                                                         9]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)
        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        return iter_num > self.swa_settings[1] \
               and iter_num % self.swa_settings[2] == 0

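    # Fold the current weights into the SWA average only when the parameter drift
    # since the last snapshot (norm of weights minus SWA buffer) exceeds the running
    # mean of previous drifts; with swa_settings[3] == 0 the average is updated
    # unconditionally.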
    def check_and_update_swa(self):
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return
            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)

            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, data_in, target_in, train=True):
        sorting_idx = get_sorting_index_with_noise_from_lengths(
            [len(x) for x in data_in], noise_frac=0.1)
        data = [data_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]

        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_target = target[n:n + bsize]
            batch_target = torch.Tensor(batch_target).to(device)

            if len(batch_target.shape) == 1:  #(B, )
                batch_target = batch_target.unsqueeze(-1)  #(B, 1)

            bce_loss = self.criterion(batch_data.predict / self.temperature,
                                      batch_target)
            weight = batch_target * self.pos_weight + (1 - batch_target)
            bce_loss = (bce_loss * weight).mean(1).sum()

            loss = bce_loss
            self.train_losses.append(bce_loss.detach().cpu().item())

            if hasattr(batch_data, 'reg_loss'):
                loss += batch_data.reg_loss

            if train:
                if self.swa_settings[0]:
                    self.check_and_update_swa()

                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()

                else:
                    # self.encoder_optim.zero_grad()
                    # self.decoder_optim.zero_grad()
                    # self.attn_optim.zero_grad()
                    self.all_optim.zero_grad()
                    loss.backward()
                    # self.encoder_optim.step()
                    # self.decoder_optim.step()
                    # self.attn_optim.step()
                    self.all_optim.step()

            loss_total += float(loss.data.cpu().item())
        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
                'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            # self.attn_optim.swap_swa_sgd()
            # self.encoder_optim.swap_swa_sgd()
            # self.decoder_optim.swap_swa_sgd()
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []

        return loss_total * bsize / N

    def predictor(self, inp_text_permutations):

        text_permutations = [
            dataset_vec.map2idxs(x.split()) for x in inp_text_permutations
        ]
        outputs = []
        bsize = 512
        N = len(text_permutations)
        for n in range(0, N, bsize):
            torch.cuda.empty_cache()
            batch_doc = text_permutations[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict)

            pred = batch_data.predict.cpu().data.numpy()
            for i in range(len(pred)):
                if math.isnan(pred[i][0]):
                    pred[i][0] = 0.5
            outputs.extend(pred)

        ret_val = [[output_i[0], 1 - output_i[0]] for output_i in outputs]
        ret_val = np.array(ret_val)

        return ret_val

    def evaluate(self, data):
        self.encoder.eval()
        self.decoder.eval()
        bsize = self.bsize
        N = len(data)

        outputs = []
        attns = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            batch_data.predict = torch.sigmoid(batch_data.predict /
                                               self.temperature)
            if self.decoder.use_attention:
                attn = batch_data.attn.cpu().data.numpy()
                attns.append(attn)

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)

        outputs = [x for y in outputs for x in y]
        if self.decoder.use_attention:
            attns = [x for y in attns for x in y]
        return outputs, attns

    def get_lime_explanations(self, data):
        explanations = []
        explainer = LimeTextExplainer(class_names=["A", "B"])
        for data_i in data:
            sentence = ' '.join(dataset_vec.map2words(data_i))
            exp = explainer.explain_instance(text_instance=sentence,
                                             classifier_fn=self.predictor,
                                             num_features=len(data_i),
                                             num_samples=5000).as_list()
            explanations.append(exp)
        return explanations

    def gradient_mem(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        grads = {'XxE': [], 'XxE[X]': [], 'H': []}

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]

            grads_xxe = []
            grads_xxex = []
            grads_H = []

            for i in range(self.decoder.output_size):
                batch_data = BatchHolder(batch_doc)
                batch_data.keep_grads = True
                batch_data.detach = True

                self.encoder(batch_data)
                self.decoder(batch_data)

                torch.sigmoid(batch_data.predict[:, i]).sum().backward()
                g = batch_data.embedding.grad
                em = batch_data.embedding
                g1 = (g * em).sum(-1)

                grads_xxex.append(g1.cpu().data.numpy())

                g1 = (g * self.encoder.embedding.weight.sum(0)).sum(-1)
                grads_xxe.append(g1.cpu().data.numpy())

                g1 = batch_data.hidden.grad.sum(-1)
                grads_H.append(g1.cpu().data.numpy())

            grads_xxe = np.array(grads_xxe).swapaxes(0, 1)
            grads_xxex = np.array(grads_xxex).swapaxes(0, 1)
            grads_H = np.array(grads_H).swapaxes(0, 1)

            grads['XxE'].append(grads_xxe)
            grads['XxE[X]'].append(grads_xxex)
            grads['H'].append(grads_H)

        for k in grads:
            grads[k] = [x for y in grads[k] for x in y]

        return grads

    def remove_and_run(self, data):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        outputs = []

        for n in tqdm(range(0, N, bsize)):
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)
            po = np.zeros(
                (batch_data.B, batch_data.maxlen, self.decoder.output_size))

            for i in range(1, batch_data.maxlen - 1):
                batch_data = BatchHolder(batch_doc)

                batch_data.seq = torch.cat(
                    [batch_data.seq[:, :i], batch_data.seq[:, i + 1:]], dim=-1)
                batch_data.lengths = batch_data.lengths - 1
                batch_data.masks = torch.cat(
                    [batch_data.masks[:, :i], batch_data.masks[:, i + 1:]],
                    dim=-1)

                self.encoder(batch_data)
                self.decoder(batch_data)

                po[:, i] = torch.sigmoid(batch_data.predict).cpu().data.numpy()

            outputs.append(po)

        outputs = [x for y in outputs for x in y]

        return outputs

    def permute_attn(self, data, num_perm=100):
        self.encoder.train()
        self.decoder.train()
        bsize = self.bsize
        N = len(data)

        permutations = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            batch_perms = np.zeros(
                (batch_data.B, num_perm, self.decoder.output_size))

            self.encoder(batch_data)
            self.decoder(batch_data)

            for i in range(num_perm):
                batch_data.permute = True
                self.decoder(batch_data)
                output = torch.sigmoid(batch_data.predict)
                batch_perms[:, i] = output.cpu().data.numpy()

            permutations.append(batch_perms)

        permutations = [x for y in permutations for x in y]

        return permutations

    def save_values(self,
                    use_dirname=None,
                    save_model=True,
                    append_to_dir_name=''):
        if use_dirname is not None:
            dirname = use_dirname
        else:
            dirname = self.dirname + append_to_dir_name
            self.last_epch_dirname = dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        json.dump(self.configuration, open(dirname + '/config.json', 'w'))

        if save_model:
            torch.save(self.encoder.state_dict(), dirname + '/enc.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname):
        self.encoder.load_state_dict(
            torch.load(dirname + '/enc.th', map_location={'cuda:1': 'cuda:0'}))
        self.decoder.load_state_dict(
            torch.load(dirname + '/dec.th', map_location={'cuda:1': 'cuda:0'}))

    def adversarial_multi(self, data):
        self.encoder.eval()
        self.decoder.eval()

        for p in self.encoder.parameters():
            p.requires_grad = False

        for p in self.decoder.parameters():
            p.requires_grad = False

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            self.adversarymulti(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, 10, L)
            predict_volatile = batch_data.predict_volatile.cpu().data.numpy(
            )  #(B, 10, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_attention(self, data, logodds_map: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_attn = []
        adverse_output = []

        logodds = np.zeros((self.encoder.vocab_size, ))
        for k, v in logodds_map.items():
            if v is not None:
                logodds[k] = abs(v)
            else:
                logodds[k] = float('-inf')
        logodds = torch.Tensor(logodds).to(device)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn = batch_data.attn  #(B, L)
            batch_data.attn_logodds = logodds[batch_data.seq]
            self.decoder.get_output_from_logodds(batch_data)

            attn_volatile = batch_data.attn_volatile.cpu().data.numpy(
            )  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict_volatile).cpu().data.numpy()  #(B, O)

            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn

    def logodds_substitution(self, data, top_logodds_words: Dict):
        self.encoder.eval()
        self.decoder.eval()

        bsize = self.bsize
        N = len(data)

        adverse_X = []
        adverse_attn = []
        adverse_output = []

        words_neg = torch.Tensor(
            top_logodds_words[0][0]).long().cuda().unsqueeze(0)
        words_pos = torch.Tensor(
            top_logodds_words[0][1]).long().cuda().unsqueeze(0)

        words_to_select = torch.cat([words_neg, words_pos], dim=0)  #(2, 5)

        for n in tqdm(range(0, N, bsize)):
            torch.cuda.empty_cache()
            batch_doc = data[n:n + bsize]
            batch_data = BatchHolder(batch_doc)

            self.encoder(batch_data)
            self.decoder(batch_data)
            predict_class = (torch.sigmoid(batch_data.predict).squeeze(-1) >
                             0.5) * 1  #(B,)

            attn = batch_data.attn  #(B, L)
            top_val, top_idx = torch.topk(attn, 5, dim=-1)
            subs_words = words_to_select[1 - predict_class.long()]  #(B, 5)

            batch_data.seq.scatter_(1, top_idx, subs_words)

            self.encoder(batch_data)
            self.decoder(batch_data)

            attn_volatile = batch_data.attn.cpu().data.numpy()  #(B, L)
            predict_volatile = torch.sigmoid(
                batch_data.predict).cpu().data.numpy()  #(B, O)
            X_volatile = batch_data.seq.cpu().data.numpy()

            adverse_X.append(X_volatile)
            adverse_attn.append(attn_volatile)
            adverse_output.append(predict_volatile)

        adverse_X = [x for y in adverse_X for x in y]
        adverse_output = [x for y in adverse_output for x in y]
        adverse_attn = [x for y in adverse_attn for x in y]

        return adverse_output, adverse_attn, adverse_X

    def predict(self, batch_data, lengths, masks):
        batch_holder = BatchHolderIndentity(batch_data, lengths, masks)
        self.encoder(batch_holder)
        self.decoder(batch_holder)
        # batch_holder.predict = torch.sigmoid(batch_holder.predict)
        predict = batch_holder.predict
        return predict
Example 2
    def train(self):
        # prepare data
        train_data = self.data('train')
        train_steps = int((len(train_data) + self.config.batch_size - 1) /
                          self.config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=self.config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True,
                                      num_workers=2)

        # prepare optimizer
        params_lr = [{
            "params": self.model.bert_parameters,
            'lr': self.config.small_lr
        }, {
            "params": self.model.other_parameters,
            'lr': self.config.large_lr
        }]
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)
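        # Assuming torchcontrib's SWA: constructed without swa_start/swa_freq it runs
        # in manual mode, so averages are only accumulated by the explicit
        # optimizer.update_swa() calls below.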

        # prepare early stopping
        early_stopping = EarlyStopping(self.model,
                                       self.config.best_model_path,
                                       big_server=BIG_GPU,
                                       mode='max',
                                       patience=10,
                                       verbose=True)

        # prepare learning rate schedule
        learning_schedual = LearningSchedual(
            optimizer, self.config.epochs, train_steps,
            [self.config.small_lr, self.config.large_lr])

        # prepare other
        aux = REModelAux(self.config, train_steps)
        moving_log = MovingData(window=500)

        ending_flag = False
        # self.model.load_state_dict(torch.load(ROOT_SAVED_MODEL + 'temp_model.ckpt'))
        #
        # with torch.no_grad():
        #     self.model.eval()
        #     print(self.eval())
        #     return
        for epoch in range(0, self.config.epochs):
            for step, (inputs, y_trues,
                       spo_info) in enumerate(train_dataloader):
                inputs = [aaa.cuda() for aaa in inputs]
                y_trues = [aaa.cuda() for aaa in y_trues]
                if epoch > 0 or step == 1000:
                    self.model.detach_bert = False
                # train ================================================================================================
                preds = self.model(inputs)
                loss = self.calculate_loss(preds, y_trues, inputs[1],
                                           inputs[2])
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                optimizer.step()

                with torch.no_grad():

                    logs = {'lr0': 0, 'lr1': 0}
                    if (epoch > 0 or step > 620) and False:  # the trailing 'and False' keeps train-time F1 computation disabled
                        sbj_f1, spo_f1 = self.calculate_train_f1(
                            spo_info[0], preds, spo_info[1:3],
                            inputs[2].cpu().numpy())
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1,
                            'sbj_correct_num': sbj_f1[0],
                            'sbj_pred_num': sbj_f1[1],
                            'sbj_true_num': sbj_f1[2],
                            'spo_correct_num': spo_f1[0],
                            'spo_pred_num': spo_f1[1],
                            'spo_true_num': spo_f1[2]
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']
                        logs['sbj_precise'], logs['sbj_recall'], logs[
                            'sbj_f1'] = calculate_f1(
                                moving_data['sbj_correct_num'],
                                moving_data['sbj_pred_num'],
                                moving_data['sbj_true_num'],
                                verbose=True)
                        logs['spo_precise'], logs['spo_recall'], logs[
                            'spo_f1'] = calculate_f1(
                                moving_data['spo_correct_num'],
                                moving_data['spo_pred_num'],
                                moving_data['spo_true_num'],
                                verbose=True)
                    else:
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']

                    # update lr
                    logs['lr0'], logs['lr1'] = learning_schedual.update_lr(
                        epoch, step)

                    if step == int(train_steps / 2) or step + 1 == train_steps:
                        self.model.eval()
                        torch.save(self.model.state_dict(),
                                   ROOT_SAVED_MODEL + 'temp_model.ckpt')
                        aux.new_line()
                        # dev ==========================================================================================
                        dev_result = self.eval()
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_sbj_precise'] = dev_result['sbj_precise']
                        logs['dev_sbj_recall'] = dev_result['sbj_recall']
                        logs['dev_sbj_f1'] = dev_result['sbj_f1']
                        logs['dev_spo_precise'] = dev_result['spo_precise']
                        logs['dev_spo_recall'] = dev_result['spo_recall']
                        logs['dev_spo_f1'] = dev_result['spo_f1']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']

                        # other thing
                        early_stopping(logs['dev_f1'])
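                        # Only fold a snapshot into the SWA average when dev F1 clears the threshold.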
                        if logs['dev_f1'] > 0.730:
                            optimizer.update_swa()

                        # test =========================================================================================
                        if (epoch + 1 == self.config.epochs and step + 1
                                == train_steps) or early_stopping.early_stop:
                            ending_flag = True
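                            # swap_swa_sgd() copies the averaged weights into the model and
                            # bn_update() recomputes BatchNorm running statistics under those
                            # weights before saving and testing.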
                            optimizer.swap_swa_sgd()
                            optimizer.bn_update(train_dataloader, self.model)
                            torch.save(self.model.state_dict(),
                                       ROOT_SAVED_MODEL + 'swa.ckpt')
                            self.test(ROOT_SAVED_MODEL + 'swa.ckpt')

                        self.model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return
Example 3
def train(model,
          device,
          trainloader,
          testloader,
          optimizer,
          criterion,
          metric,
          epochs,
          learning_rate,
          swa=True,
          enable_scheduler=True,
          model_arch=''):
    '''
    Function to perform model training.
    '''
    model.to(device)
    steps = 0
    running_loss = 0
    running_metric = 0
    print_every = 100

    train_losses = []
    test_losses = []
    train_metrics = []
    test_metrics = []

    if swa:
        # initialize stochastic weight averaging
        opt = SWA(optimizer)
    else:
        opt = optimizer

    # learning rate cosine annealing
    if enable_scheduler:
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   len(trainloader),
                                                   eta_min=0.0000001)

    for epoch in range(epochs):

        if enable_scheduler:
            scheduler.step()

        for inputs, labels in trainloader:

            steps += 1
            # Move input and label tensors to the default device
            inputs, labels = inputs.to(device), labels.to(device)

            opt.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels.float())
            loss.backward()
            opt.step()

            running_loss += loss.item()  # use .item() so the graph is not retained across steps
            running_metric += metric(outputs, labels.float())

            if steps % print_every == 0:
                test_loss = 0
                test_metric = 0
                model.eval()
                with torch.no_grad():
                    for inputs, labels in testloader:
                        inputs, labels = inputs.to(device), labels.to(device)
                        outputs = model(inputs)

                        test_loss += criterion(outputs, labels.float())

                        test_metric += metric(outputs, labels.float())

                print(f"Epoch {epoch+1}/{epochs}.. "
                      f"Train loss: {running_loss/print_every:.3f}.. "
                      f"Test loss: {test_loss/len(testloader):.3f}.. "
                      f"Train metric: {running_metric/print_every:.3f}.. "
                      f"Test metric: {test_metric/len(testloader):.3f}.. ")

                train_losses.append(running_loss / print_every)
                test_losses.append(test_loss / len(testloader))
                train_metrics.append(running_metric / print_every)
                test_metrics.append(test_metric / len(testloader))

                running_loss = 0
                running_metric = 0

                model.train()
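                # Fold the current weights into the SWA running average every
                # `print_every` steps; the averaged weights are swapped in only
                # once, after training (see swap_swa_sgd below).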
                if swa:
                    opt.update_swa()

        save_model(model,
                   model_arch,
                   learning_rate,
                   epochs,
                   train_losses,
                   test_losses,
                   train_metrics,
                   test_metrics,
                   filepath='models_checkpoints')

    if swa:
        opt.swap_swa_sgd()

    return model, train_losses, test_losses, train_metrics, test_metrics
Example 4
class Train(object):
    """Train class.
  """
    def __init__(self, train_ds, val_ds, fold):
        self.fold = fold

        self.init_lr = cfg.TRAIN.init_lr
        self.warmup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = Net().to(self.device)

        self.load_weight()

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            cfg.TRAIN.weight_decay_factor
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if 'Adamw' in cfg.TRAIN.opt:

            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                               lr=self.init_lr,
                                               eps=1.e-5)
        else:
            self.optimizer = torch.optim.SGD(optimizer_grouped_parameters,
                                             lr=0.001,
                                             momentum=0.9)

        if cfg.TRAIN.SWA > 0:
            ## wrap the optimizer with SWA; averages are updated and swapped manually in custom_loop
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.ema = EMA(self.model, 0.999)

        self.ema.register()
        ###control vars
        self.iter_num = 0

        self.train_ds = train_ds

        self.val_ds = val_ds

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)

    def custom_loop(self):
        """Custom training and testing loop.

        Runs training and evaluation over self.train_ds and self.val_ds for
        self.epochs epochs, logging metrics and saving a checkpoint at the
        end of every epoch.
        """
        def distributed_train_epoch(epoch_num):

            summary_loss = AverageMeter()
            acc_score = ACCMeter()
            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False
            for step in range(self.train_ds.size):

                if epoch_num < 10:
                    ### execute warm-up during the first epochs
                    if self.warmup_step > 0:
                        if self.iter_num < self.warmup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(
                                    self.warmup_step) * self.init_lr
                                lr = param_group['lr']

                            logger.info('warm up with learning rate: [%f]' %
                                        (lr))

                start = time.time()

                images, data, target = self.train_ds()
                images = torch.from_numpy(images).to(self.device).float()
                data = torch.from_numpy(data).to(self.device).float()
                target = torch.from_numpy(target).to(self.device).float()

                batch_size = data.shape[0]

                output = self.model(images, data)

                current_loss = self.criterion(output, target)

                summary_loss.update(current_loss.detach().item(), batch_size)
                acc_score.update(target, output)
                self.optimizer.zero_grad()

                if cfg.TRAIN.mix_precision:
                    with amp.scale_loss(current_loss,
                                        self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    current_loss.backward()

                self.optimizer.step()
                if cfg.MODEL.ema:
                    self.ema.update()
                self.iter_num += 1
                time_cost_per_batch = time.time() - start

                images_per_sec = cfg.TRAIN.batch_size / time_cost_per_batch

                if self.iter_num % cfg.TRAIN.log_interval == 0:

                    log_message = '[fold %d], '\
                                  'Train Step %d, ' \
                                  'summary_loss: %.6f, ' \
                                  'accuracy: %.6f, ' \
                                  'time: %.6f, '\
                                  'speed %d images/persec'% (
                                      self.fold,
                                      step,
                                      summary_loss.avg,
                                      acc_score.avg,
                                      time.time() - start,
                                      images_per_sec)
                    logger.info(log_message)

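            # Once epoch_num reaches cfg.TRAIN.SWA, add the end-of-epoch weights to the running average.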
            if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                self.optimizer.update_swa()

            return summary_loss, acc_score

        def distributed_test_epoch(epoch_num):
            summary_loss = AverageMeter()
            acc_score = ACCMeter()

            self.model.eval()
            t = time.time()
            with torch.no_grad():
                for step in range(self.val_ds.size):
                    images, data, target = self.train_ds()
                    images = torch.from_numpy(images).to(self.device).float()
                    data = torch.from_numpy(data).to(self.device).float()
                    target = torch.from_numpy(target).to(self.device).float()
                    batch_size = data.shape[0]

                    output = self.model(images, data)
                    loss = self.criterion(output, target)

                    summary_loss.update(loss.detach().item(), batch_size)
                    acc_score.update(target, output)

                    if step % cfg.TRAIN.log_interval == 0:

                        log_message = '[fold %d], '\
                                      'Val Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'acc: %.6f, ' \
                                      'time: %.6f' % (
                                      self.fold,step, summary_loss.avg, acc_score.avg, time.time() - t)

                        logger.info(log_message)

            return summary_loss, acc_score

        for epoch in range(self.epochs):

            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss, acc_score = distributed_train_epoch(epoch)

            train_epoch_log_message = '[fold %d], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' accuracy: %.5f,' \
                                      ' time:%.5f' % (
                                      self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:

                ###switch to avg model
                self.optimizer.swap_swa_sgd()

            ## switch to EMA weights
            if cfg.MODEL.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:

                summary_loss, acc_score = distributed_test_epoch(epoch)

                val_epoch_log_message = '[fold %d], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' accuracy: %.5f,' \
                                        ' time:%.5f' % (
                                         self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)
            ###save the best auc model

            #### save the model at the end of every epoch
            current_model_saved_name = './models/fold%d_epoch_%d_val_loss%.6f.pth' % (
                self.fold, epoch, summary_loss.avg)

            logger.info('A model saved to %s' % current_model_saved_name)
            torch.save(self.model.state_dict(), current_model_saved_name)

            ####switch back
            if cfg.MODEL.ema:
                self.ema.restore()

            # save_checkpoint({
            #           'state_dict': self.model.state_dict(),
            #           },iters=epoch,tag=current_model_saved_name)

            if cfg.TRAIN.SWA > 0 and epoch > cfg.TRAIN.SWA:
                ###switch back to plain model to train next epoch
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
Example 5
def train(train_df, CONFIG):
    # set-up
    seed_everything(CONFIG['SEED'])
    torch.manual_seed(CONFIG['TORCH_SEED'])
    mlflow.log_params(CONFIG)

    TRAIN_LEN = len(train_df)
    train_dataset = TweetDataset(train_df, CONFIG)
    CRITERION = define_criterion(CONFIG)

    folds = StratifiedKFold(n_splits=CONFIG["FOLD"],
                            shuffle=True,
                            random_state=CONFIG["SEED"])
    for n_fold, (train_idx, valid_idx) in enumerate(
            folds.split(train_dataset.df['textID'],
                        train_dataset.df['sentiment'])):
        if n_fold != CONFIG["FOLD_NUM"]:
            continue

        ## define the DataLoaders
        train = torch.utils.data.Subset(train_dataset, train_idx)
        valid = torch.utils.data.Subset(train_dataset, valid_idx)

        DATA_IN_EPOCH = len(train)
        TOTAL_DATA = DATA_IN_EPOCH * CONFIG["EPOCHS"]
        T_TOTAL = int(CONFIG["EPOCHS"] * DATA_IN_EPOCH /
                      CONFIG["TRAIN_BATCH_SIZE"])

        ## initialize the model and optimizer
        model = build_model(CONFIG)
        model.to(DEVICE)
        model.train()

        ## From 20/05/17
        param_optimizer = list(model.named_parameters())
        bert_params = [n for n, p in param_optimizer if "bert" in n]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [
            ## BERT param
            {
                'params': [
                    p for n, p in param_optimizer
                    if (not any(nd in n
                                for nd in no_decay)) and (n in bert_params)
                ],
                'weight_decay':
                CONFIG["WEIGHT_DECAY"],
                'lr':
                CONFIG['LR'] * 1,
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if (any(nd in n for nd in no_decay)) and (n in bert_params)
                ],
                'weight_decay':
                0.0,
                'lr':
                CONFIG['LR'] * 1,
            },
            ## Other param
            {
                'params': [
                    p for n, p in param_optimizer
                    if (not any(nd in n
                                for nd in no_decay)) and (n not in bert_params)
                ],
                'weight_decay':
                CONFIG["WEIGHT_DECAY"],
                'lr':
                CONFIG['LR'] * 1,
            },
            {
                'params': [
                    p for n, p in param_optimizer
                    if (any(nd in n
                            for nd in no_decay)) and (n not in bert_params)
                ],
                'weight_decay':
                0.0,
                'lr':
                CONFIG['LR'] * 1,
            },
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, eps=4e-5)
        if CONFIG['SWA']:
            optimizer = SWA(optimizer)

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=int(CONFIG["WARMUP"] * T_TOTAL),
            num_training_steps=T_TOTAL)

        train_sampler = SentimentBalanceSampler(train, CONFIG)

        train_loader = DataLoader(train,
                                  batch_size=CONFIG["TRAIN_BATCH_SIZE"],
                                  shuffle=False,
                                  sampler=train_sampler,
                                  collate_fn=TweetCollate(CONFIG),
                                  num_workers=1)

        valid_loader = DataLoader(
            valid,
            batch_size=CONFIG["VALID_BATCH_SIZE"],
            shuffle=False,
            # sampler     = valid_sampler,
            collate_fn=TweetCollate(CONFIG),
            num_workers=1)

        n_data = 0
        n_e_data = 0
        best_val = 0.0
        best_val_neu = 0.0
        best_val_pos = 0.0
        best_val_neg = 0.0

        t_batch = 0
        while n_data < TOTAL_DATA:
            print(f"Epoch : {int(n_data/DATA_IN_EPOCH)}")

            n_batch = 0
            loss_list = []
            jac_token_list = []
            jac_text_list = []
            jac_sentiment_list = []
            jac_cl_text_list = []

            output_list = []
            target_list = []

            for batch in tqdm(train_loader):
                textID = batch['textID']
                text = batch['text']
                sentiment = batch['sentiment']
                cl_text = batch['cl_text']
                selected_text = batch['selected_text']
                cl_selected_text = batch['cl_selected_text']
                text_idx = batch['text_idx']
                offsets = batch['offsets']

                tokenized_text = batch['tokenized_text'].to(DEVICE)
                mask = batch['mask'].to(DEVICE)
                mask_out = batch['mask_out'].to(DEVICE)
                token_type_ids = batch['token_type_ids'].to(DEVICE)
                weight = batch['weight'].to(DEVICE)
                target = batch['target'].to(DEVICE)

                ep = int(n_data / DATA_IN_EPOCH)
                n_data += len(textID)
                n_e_data += len(textID)
                n_batch += 1
                t_batch += 1

                model.zero_grad()
                # optimizer.zero_grad()
                output = model(input_ids=tokenized_text,
                               attention_mask=mask,
                               token_type_ids=token_type_ids,
                               mask_out=mask_out)

                loss = CRITERION(output, target)

                loss = loss * weight
                loss.mean().backward()
                loss = loss.detach().cpu().numpy().tolist()

                optimizer.step()

                if t_batch < T_TOTAL * 0.50:
                    scheduler.step()

                loss_list.extend(loss)

                output = output.detach().cpu().numpy()
                target = target.detach().cpu().numpy()

                jac = calc_jaccard(output, batch, CONFIG)

                jac_token_list.extend(jac['jaccard_token'].tolist())
                jac_cl_text_list.extend(jac['jaccard_cl_text'].tolist())
                jac_text_list.extend(jac['jaccard_text'].tolist())
                jac_sentiment_list.extend(sentiment)

                if ((((ep > 0) &
                      (n_batch %
                       (int(5 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                     ((n_batch %
                       (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                     (n_data >= TOTAL_DATA)) and
                    (CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50)):
                    optimizer.update_swa()

                if (
                    ((ep > 0) &
                     (n_batch %
                      (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                    ((n_batch %
                      (int(50 * 32 / CONFIG["TRAIN_BATCH_SIZE"])) == 0)) |
                    (n_data >= TOTAL_DATA)
                ):  # ((n_data>=0)&(n_data<=1600)|(n_data>=21000)&(n_data<=23000))&

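                    # Swap the SWA-averaged weights in for validation; they are swapped
                    # back out after logging so training resumes from the raw weights.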
                    if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50:
                        # optimizer.update_swa()
                        optimizer.swap_swa_sgd()

                    val = create_valid(model, valid_loader, CONFIG)

                    trn_loss = np.array(loss_list).mean()
                    trn_jac_token = np.array(jac_token_list).mean()
                    trn_jac_cl_text = np.array(jac_cl_text_list).mean()

                    trn_jac_text = np.array(jac_text_list).mean()
                    trn_jac_text_neu = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'neutral'].mean()
                    trn_jac_text_pos = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'positive'].mean()
                    trn_jac_text_neg = np.array(jac_text_list)[np.array(
                        jac_sentiment_list) == 'negative'].mean()

                    val_loss = val['loss'].mean()
                    val_jac_token = val['jaccard_token'].mean()
                    val_jac_cl_text = val['jaccard_cl_text'].mean()

                    val_jac_text = val['jaccard_text'].mean()
                    val_jac_text_neu = val['jaccard_text'][val['sentiment'] ==
                                                           'neutral'].mean()
                    val_jac_text_pos = val['jaccard_text'][val['sentiment'] ==
                                                           'positive'].mean()
                    val_jac_text_neg = val['jaccard_text'][val['sentiment'] ==
                                                           'negative'].mean()

                    loss_list = []
                    jac_token_list = []
                    jac_cl_text_list = []
                    jac_text_list = []
                    jac_sentiment_list = []

                    # mlflow
                    metrics = {
                        "lr": optimizer.param_groups[0]['lr'],
                        "trn_loss": trn_loss,
                        "trn_jac_text_neu": trn_jac_text_neu,
                        "trn_jac_text_pos": trn_jac_text_pos,
                        "trn_jac_text_neg": trn_jac_text_neg,
                        "trn_jac_text": trn_jac_text,
                        "val_loss": val_loss,
                        "val_jac_text_neu": val_jac_text_neu,
                        "val_jac_text_pos": val_jac_text_pos,
                        "val_jac_text_neg": val_jac_text_neg,
                        "val_jac_text": val_jac_text,
                    }
                    mlflow.log_metrics(metrics, step=n_data)

                    if CONFIG['SWA'] and t_batch < T_TOTAL * 0.50:
                        pass
                    else:
                        if best_val < val_jac_text:
                            best_val = val_jac_text
                            best_model = copy.deepcopy(model)

                    if CONFIG['SWA'] and ep > 0 and t_batch >= T_TOTAL * 0.50:
                        optimizer.swap_swa_sgd()

                if n_e_data >= DATA_IN_EPOCH:
                    n_e_data -= DATA_IN_EPOCH

                if n_data >= TOTAL_DATA:
                    filepath = os.path.join(FILE_DIR, OUTPUT_DIR, "model.pth")
                    torch.save(best_model.state_dict(), filepath)

                    # mlflow
                    mlflow.log_artifact(filepath)
                    break
Example 6
def fit(
    model,
    train_dataset,
    val_dataset,
    optimizer_name="adam",
    samples_per_player=0,
    epochs=50,
    batch_size=32,
    val_bs=32,
    warmup_prop=0.1,
    lr=1e-3,
    acc_steps=1,
    swa_first_epoch=50,
    num_classes_aux=0,
    aux_mode="sigmoid",
    verbose=1,
    first_epoch_eval=0,
    device="cuda",
):
    """
    Fitting function for the classification task.

    Args:
        model (torch model): Model to train.
        train_dataset (torch dataset): Dataset to train with.
        val_dataset (torch dataset): Dataset to validate with.
        optimizer_name (str, optional): Optimizer name. Defaults to 'adam'.
        samples_per_player (int, optional): Number of images to use per player. Defaults to 0.
        epochs (int, optional): Number of epochs. Defaults to 50.
        batch_size (int, optional): Training batch size. Defaults to 32.
        val_bs (int, optional): Validation batch size. Defaults to 32.
        warmup_prop (float, optional): Warmup proportion. Defaults to 0.1.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        acc_steps (int, optional): Accumulation steps. Defaults to 1.
        swa_first_epoch (int, optional): Epoch to start applying SWA from. Defaults to 50.
        num_classes_aux (int, optional): Number of auxiliary classes. Defaults to 0.
        aux_mode (str, optional): Mode for auxiliary classification. Defaults to 'sigmoid'.
        verbose (int, optional): Period (in epochs) to display logs at. Defaults to 1.
        first_epoch_eval (int, optional): Epoch to start evaluating at. Defaults to 0.
        device (str, optional): Device for torch. Defaults to "cuda".

    Returns:
        numpy array [len(val_dataset)]: Last predictions on the validation data.
        numpy array [len(val_dataset) x num_classes_aux]: Last aux predictions on the val data.
    """

    optimizer = define_optimizer(optimizer_name, model.parameters(), lr=lr)

    if swa_first_epoch <= epochs:
        optimizer = SWA(optimizer)

    loss_fct = nn.BCEWithLogitsLoss()
    loss_fct_aux = nn.BCEWithLogitsLoss(
    ) if aux_mode == "sigmoid" else nn.CrossEntropyLoss()
    aux_loss_weight = 1 if num_classes_aux else 0

    if samples_per_player:
        sampler = PlayerSampler(
            RandomSampler(train_dataset),
            train_dataset.players,
            batch_size=batch_size,
            drop_last=True,
            samples_per_player=samples_per_player,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_sampler=sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

        print(
            f"Using {len(train_loader)} out of {len(train_dataset) // batch_size} "
            f"batches by limiting to {samples_per_player} samples per player.\n"
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_bs,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    num_training_steps = int(epochs * len(train_loader))
    num_warmup_steps = int(warmup_prop * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                                num_training_steps)

    for epoch in range(epochs):
        model.train()

        start_time = time.time()
        optimizer.zero_grad()

        avg_loss = 0

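        # From the epoch after swa_first_epoch onward, the model starts the epoch with
        # the averaged weights (swapped in at the end of the previous epoch for
        # evaluation), so swap the raw weights back in before training.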
        if epoch + 1 > swa_first_epoch:
            optimizer.swap_swa_sgd()

        for batch in train_loader:
            images = batch[0].to(device)
            y_batch = batch[1].to(device).view(-1).float()
            y_batch_aux = batch[2].to(device).float()
            y_batch_aux = y_batch_aux.float(
            ) if aux_mode == "sigmoid" else y_batch_aux.long()

            y_pred, y_pred_aux = model(images)

            loss = loss_fct(y_pred.view(-1), y_batch)
            if aux_loss_weight:
                loss += aux_loss_weight * loss_fct_aux(y_pred_aux, y_batch_aux)
            loss.backward()

            avg_loss += loss.item() / len(train_loader)
            optimizer.step()
            scheduler.step()
            for param in model.parameters():
                param.grad = None

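        # Add the freshly trained weights to the SWA average, then swap the averaged
        # weights in for the evaluation pass below.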
        if epoch + 1 >= swa_first_epoch:
            optimizer.update_swa()
            optimizer.swap_swa_sgd()

        preds = np.empty(0)
        preds_aux = np.empty((0, num_classes_aux))
        model.eval()
        avg_val_loss, auc, scores_aux = 0., 0., 0.
        if epoch + 1 >= first_epoch_eval or epoch + 1 == epochs:
            with torch.no_grad():
                for batch in val_loader:
                    images = batch[0].to(device)
                    y_batch = batch[1].to(device).view(-1).float()
                    y_aux = batch[2].to(device).float()
                    y_batch_aux = y_aux.float(
                    ) if aux_mode == "sigmoid" else y_aux.long()

                    y_pred, y_pred_aux = model(images)

                    loss = loss_fct(y_pred.detach().view(-1), y_batch)
                    if aux_loss_weight:
                        loss += aux_loss_weight * loss_fct_aux(
                            y_pred_aux.detach(), y_batch_aux)

                    avg_val_loss += loss.item() / len(val_loader)

                    y_pred = torch.sigmoid(y_pred).view(-1)
                    preds = np.concatenate(
                        [preds, y_pred.detach().cpu().numpy()])

                    if num_classes_aux:
                        y_pred_aux = (y_pred_aux.sigmoid() if aux_mode
                                      == "sigmoid" else y_pred_aux.softmax(-1))
                        preds_aux = np.concatenate(
                            [preds_aux,
                             y_pred_aux.detach().cpu().numpy()])

            auc = roc_auc_score(val_dataset.labels, preds)

            if num_classes_aux:
                if aux_mode == "sigmoid":
                    scores_aux = np.round(
                        [
                            roc_auc_score(val_dataset.aux_labels[:, i],
                                          preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
                else:
                    scores_aux = np.round(
                        [
                            roc_auc_score((val_dataset.aux_labels
                                           == i).astype(int), preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
            else:
                scores_aux = 0

        elapsed_time = time.time() - start_time
        if (epoch + 1) % verbose == 0:
            elapsed_time = elapsed_time * verbose
            lr = scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch + 1:02d}/{epochs:02d} \t lr={lr:.1e}\t t={elapsed_time:.0f}s \t"
                f"loss={avg_loss:.3f}",
                end="\t",
            )

            if epoch + 1 >= first_epoch_eval:
                print(
                    f"val_loss={avg_val_loss:.3f} \t auc={auc:.3f}\t aucs_aux={scores_aux}"
                )
            else:
                print("")

    del val_loader, train_loader, y_pred
    torch.cuda.empty_cache()

    return preds, preds_aux
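
A minimal, self-contained sketch of the torchcontrib SWA pattern that the surrounding examples rely on: wrap the base optimizer, call update_swa() once the averaging phase begins, then swap_swa_sgd() (and bn_update() when the network has BatchNorm layers) before evaluating or saving. The toy model, data, and epoch counts below are illustrative assumptions, not taken from any of the examples.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(10, 16), nn.ReLU(), nn.Linear(16, 1))
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=16)
criterion = nn.MSELoss()

base_opt = torch.optim.Adam(model.parameters(), lr=1e-3)
optimizer = SWA(base_opt)           # manual mode: averaging is driven explicitly
swa_start, epochs = 3, 6

for epoch in range(epochs):
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        optimizer.update_swa()      # fold the current weights into the running average

optimizer.swap_swa_sgd()            # copy the averaged weights into the model
optimizer.bn_update(loader, model)  # recompute BatchNorm statistics (a no-op here: no BN layers)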
Example no. 7
    def train(self, train_inputs):
        config = self.config.fitting
        model = train_inputs['model']
        train_data = train_inputs['train_data']
        dev_data = train_inputs['dev_data']
        epoch_start = train_inputs['epoch_start']

        train_steps = int((len(train_data) + config.batch_size - 1) / config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True)
        params_lr = []
        for key, value in model.get_params().items():
            if key in config.lr:
                params_lr.append({"params": value, 'lr': config.lr[key]})
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)

        early_stopping = EarlyStopping(model, ROOT_WEIGHT, mode='max', patience=3)
        learning_schedual = LearningSchedual(optimizer, config.epochs, config.end_epoch, train_steps, config.lr)

        aux = ModelAux(self.config, train_steps)
        moving_log = MovingData(window=100)

        ending_flag = False
        detach_flag = False
        swa_flag = False
        fgm = FGM(model)
        for epoch in range(epoch_start, config.epochs):
            for step, (inputs, targets, others) in enumerate(train_dataloader):
                inputs = dict([(key, value[0].cuda() if value[1] else value[0]) for key, value in inputs.items()])
                targets = dict([(key, value.cuda()) for key, value in targets.items()])
                if epoch > 0 and step == 0:
                    model.detach_ptm(False)
                    detach_flag = False
                if epoch == 0 and step == 0:
                    model.detach_ptm(True)
                    detach_flag = True
                # train ================================================================================================
                preds = model(inputs, en_decode=config.verbose)
                loss = model.cal_loss(preds, targets, inputs['mask'])
                loss['back'].backward()

                # adversarial training
                if (not detach_flag) and config.en_fgm:
                    fgm.attack(emb_name='word_embeddings')  # add an adversarial perturbation to the embeddings
                    preds_adv = model(inputs, en_decode=False)
                    loss_adv = model.cal_loss(preds_adv, targets, inputs['mask'])
                    loss_adv['back'].backward()  # backpropagate, accumulating the adversarial gradients on top of the normal ones
                    fgm.restore(emb_name='word_embeddings')  # restore the original embedding parameters

                # torch.nn.utils.clip_grad_norm(model.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    logs = {}
                    if config.verbose:
                        pred_entity_point = model.find_entity(preds['pred'], others['raw_text'])
                        cn, pn, tn = self.calculate_f1(pred_entity_point, others['raw_entity'])
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1,
                                        'correct_num': cn, 'pred_num': pn,
                                        'true_num': tn}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                        logs['precise'], logs['recall'], logs['f1'] = calculate_f1(moving_data['correct_num'],
                                                                                   moving_data['pred_num'],
                                                                                   moving_data['true_num'],
                                                                                   verbose=True)
                    else:
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                    # update lr
                    lr_data = learning_schedual.update_lr(epoch, step)
                    logs.update(lr_data)

                    if step + 1 == train_steps:
                        model.eval()
                        aux.new_line()

                        # dev ==========================================================================================

                        eval_inputs = {'model': model,
                                       'data': dev_data,
                                       'type_data': 'dev',
                                       'outfile': train_inputs['dev_res_file']}
                        dev_result = self.eval(eval_inputs)
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']
                        if logs['dev_f1'] > 0.80:
                            torch.save(model.state_dict(),
                                       "{}/auto_save_{:.6f}.ckpt".format(ROOT_WEIGHT, logs['dev_f1']))
                        if (epoch > 3 or swa_flag) and config.en_swa:
                            optimizer.update_swa()
                            swa_flag = True
                        early_stop, best_score = early_stopping(logs['dev_f1'])

                        # test =========================================================================================
                        if (epoch + 1 == config.epochs and step + 1 == train_steps) or early_stop:
                            ending_flag = True
                            if swa_flag:
                                optimizer.swap_swa_sgd()
                                optimizer.bn_update(train_dataloader, model)

                        model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return best_score
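
Example no. 7 combines SWA with FGM-style adversarial training on the word embeddings, but the FGM helper itself is not part of the snippet. The following is a minimal sketch of how such a helper is commonly implemented; the epsilon value and class layout are assumptions, not taken from the example. attack() nudges the embedding weights along the normalized gradient, and restore() puts the saved weights back after the adversarial backward pass.

import torch

class FGM:
    """Fast Gradient Method: perturb embedding weights along the gradient direction."""

    def __init__(self, model, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()   # save the clean weights
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]           # put the clean weights back
        self.backup = {}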
Example no. 8
def main():

    maxIOU = 0.0
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    #     test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)

    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if args.train:
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  #ignore_index=255 weight=w
        model = nn.DataParallel(model).cuda()

        for param in model.parameters():
            param.requires_grad = True

        optimizer1 = optim.SGD(model.parameters(),
                               lr=config.lr,
                               momentum=0.9,
                               weight_decay=1e-4)
        optimizer = SWA(optimizer1)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=(args.epochs // 9) + 1)

        dataset_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=args.train,
            pin_memory=True,
            num_workers=args.workers)

        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))

            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):

                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()

                losses.update(loss.item(), args.batch_size)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()

                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))

            if epoch > 5:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, model_fname % (epoch + 1))
        optimizer.swap_swa_sgd()
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (665 + 1))
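
torchcontrib's SWA wrapper used above was later upstreamed into PyTorch itself as torch.optim.swa_utils (available since PyTorch 1.6). For comparison with the SGD + cosine-annealing loop in this example, here is a minimal, self-contained sketch of the built-in equivalent; the toy model, data, learning rates, and swa_start epoch are illustrative assumptions. AveragedModel plays the role of update_swa()/swap_swa_sgd(), and update_bn() replaces optimizer.bn_update().

import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from torch.utils.data import DataLoader, TensorDataset

model = nn.Sequential(nn.Linear(10, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=16)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

swa_model = AveragedModel(model)               # keeps the running average of the weights
swa_scheduler = SWALR(optimizer, swa_lr=0.005)
swa_start, epochs = 3, 6

for epoch in range(epochs):
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)     # analogue of optimizer.update_swa()
        swa_scheduler.step()

update_bn(loader, swa_model)                   # analogue of optimizer.bn_update(loader, model)
swa_model.eval()                               # evaluate and save swa_model, not model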
Example no. 9
class Model() :
    def __init__(self, configuration, pre_embed=None) :
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed

        encoder_copy = deepcopy(configuration['model']['encoder'])
        self.Pencoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device)
        self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device)

        configuration['model']['decoder']['hidden_size'] = self.Pencoder.output_size
        self.decoder = AttnDecoderQA.from_params(Params(configuration['model']['decoder'])).to(device)

        self.bsize = configuration['training']['bsize']

        self.adversary_multi = AdversaryMulti(self.decoder)

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters())
        self.optim = torch.optim.Adam(self.params, weight_decay=weight_decay, amsgrad=True)
        # self.optim = torch.optim.Adagrad(self.params, lr=0.05, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.swa_settings = configuration['training']['swa']
        if self.swa_settings[0]:
            self.swa_all_optim = SWA(self.optim)
            self.running_norms = []

    @classmethod
    def init_from_config(cls, dirname, **kwargs) :
        config = json.load(open(dirname + '/config.json', 'r'))
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[
            [1, 2, 5, 6, 10, 11, 14, 15, 18, 20, 24, 26]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(
                param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)
        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        return iter_num > self.swa_settings[1] \
               and iter_num % self.swa_settings[2] == 0


    def check_and_update_swa(self):
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return
            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)

            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, train_data, train=True) :
        docs_in = train_data.P
        question_in = train_data.Q
        entity_masks_in = train_data.E
        target_in = train_data.A

        sorting_idx = get_sorting_index_with_noise_from_lengths([len(x) for x in docs_in], noise_frac=0.1)
        docs = [docs_in[i] for i in sorting_idx]
        questions = [question_in[i] for i in sorting_idx]
        entity_masks = [entity_masks_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]
        
        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            batch_target = target[n:n+bsize]
            batch_target = torch.LongTensor(batch_target).to(device)

            ce_loss = self.criterion(batch_data.predict, batch_target)

            loss = ce_loss

            if hasattr(batch_data, 'reg_loss') :
                loss += batch_data.reg_loss

            if train :
                if self.swa_settings[0]:
                    self.check_and_update_swa()

                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()
                else:
                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()

            loss_total += float(loss.data.cpu().item())
        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
            'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            # self.attn_optim.swap_swa_sgd()
            # self.encoder_optim.swap_swa_sgd()
            # self.decoder_optim.swap_swa_sgd()
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []
        return loss_total*bsize/N

    def evaluate(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E
        
        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)

        outputs = []
        attns = []
        scores = []
        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            prediction_scores = batch_data.predict.cpu().data.numpy()
            batch_data.predict = torch.argmax(batch_data.predict, dim=-1)
            if self.decoder.use_attention :
                attn = batch_data.attn
                attns.append(attn.cpu().data.numpy())

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)
            scores.append(prediction_scores)

            

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        scores = [x for y in scores for x in y]

        return outputs, attns, scores

    def save_values(self, use_dirname=None, save_model=True) :
        if use_dirname is not None :
            dirname = use_dirname
        else :
            dirname = self.dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        json.dump(self.configuration, open(dirname + '/config.json', 'w'))

        if save_model :
            torch.save(self.Pencoder.state_dict(), dirname + '/encP.th')
            torch.save(self.Qencoder.state_dict(), dirname + '/encQ.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname) :
        self.Pencoder.load_state_dict(torch.load(dirname + '/encP.th'))
        self.Qencoder.load_state_dict(torch.load(dirname + '/encQ.th'))
        self.decoder.load_state_dict(torch.load(dirname + '/dec.th'))

    def permute_attn(self, data, num_perm=100) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)

        permutations_predict = []
        permutations_diff = []

        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            predict_true = batch_data.predict.clone().detach()

            batch_perms_predict = np.zeros((batch_data.P.B, num_perm))
            batch_perms_diff = np.zeros((batch_data.P.B, num_perm))

            for i in range(num_perm) :
                batch_data.permute = True
                self.decoder(batch_data)

                predict = torch.argmax(batch_data.predict, dim=-1)
                batch_perms_predict[:, i] = predict.cpu().data.numpy()
            
                predict_difference = self.adversary_multi.output_diff(batch_data.predict, predict_true)
                batch_perms_diff[:, i] = predict_difference.squeeze(-1).cpu().data.numpy()
                
            permutations_predict.append(batch_perms_predict)
            permutations_diff.append(batch_perms_diff)

        permutations_predict = [x for y in permutations_predict for x in y]
        permutations_diff = [x for y in permutations_diff for x in y]
        
        return permutations_predict, permutations_diff

    def adversarial_multi(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.eval()
        self.Qencoder.eval()
        self.decoder.eval()

        print(self.adversary_multi.K)
        
        self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters())

        for p in self.params :
            p.requires_grad = False

        bsize = self.bsize
        N = len(questions)
        batches = list(range(0, N, bsize))

        outputs, attns, diffs = [], [], []

        for n in tqdm(batches) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            self.adversary_multi(batch_data)

            predict_volatile = torch.argmax(batch_data.predict_volatile, dim=-1)
            outputs.append(predict_volatile.cpu().data.numpy())
            
            attn = batch_data.attn_volatile
            attns.append(attn.cpu().data.numpy())

            predict_difference = self.adversary_multi.output_diff(batch_data.predict_volatile, batch_data.predict.unsqueeze(1))
            diffs.append(predict_difference.cpu().data.numpy())

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        diffs = [x for y in diffs for x in y]
        
        return outputs, attns, diffs

    def gradient_mem(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)

        grads = {'XxE' : [], 'XxE[X]' : [], 'H' : []}

        for n in range(0, N, bsize) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            batch_data.P.keep_grads = True
            batch_data.detach = True

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)
            
            max_predict = torch.argmax(batch_data.predict, dim=-1)
            prob_predict = nn.Softmax(dim=-1)(batch_data.predict)

            max_class_prob = torch.gather(prob_predict, -1, max_predict.unsqueeze(-1))
            max_class_prob.sum().backward()

            g = batch_data.P.embedding.grad
            em = batch_data.P.embedding
            g1 = (g * em).sum(-1)
            
            grads['XxE[X]'].append(g1.cpu().data.numpy())
            
            g1 = (g * self.Pencoder.embedding.weight.sum(0)).sum(-1)
            grads['XxE'].append(g1.cpu().data.numpy())
            
            g1 = batch_data.P.hidden.grad.sum(-1)
            grads['H'].append(g1.cpu().data.numpy())


        for k in grads :
            grads[k] = [x for y in grads[k] for x in y]
                    
        return grads       

    def remove_and_run(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)
        output_diffs = []

        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            po = np.zeros((batch_data.P.B, batch_data.P.maxlen))

            for i in range(1, batch_data.P.maxlen - 1) :
                batch_doc = BatchHolder(docs[n:n+bsize])

                batch_doc.seq = torch.cat([batch_doc.seq[:, :i], batch_doc.seq[:, i+1:]], dim=-1)
                batch_doc.lengths = batch_doc.lengths - 1
                batch_doc.masks = torch.cat([batch_doc.masks[:, :i], batch_doc.masks[:, i+1:]], dim=-1)

                batch_data_loop = BatchMultiHolder(P=batch_doc, Q=batch_ques)
                batch_data_loop.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

                self.Pencoder(batch_data_loop.P)
                self.decoder(batch_data_loop)

                predict_difference = self.adversary_multi.output_diff(batch_data_loop.predict, batch_data.predict)

                po[:, i] = predict_difference.squeeze(-1).cpu().data.numpy()

            output_diffs.append(po)

        output_diffs = [x for y in output_diffs for x in y]
        
        return output_diffs
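
Example no. 9 gates its SWA updates on torchcontrib internals: the per-parameter 'swa_buffer' that update_swa() creates in the optimizer state, and the per-group 'step_counter' that step() increments. Below is a minimal sketch of where that state lives and how the buffer-to-weights distance used by check_and_update_swa() can be computed; the toy model and data are illustrative assumptions.

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(4, 2)
optim_swa = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))

x, y = torch.randn(8, 4), torch.randn(8, 2)
for _ in range(5):
    optim_swa.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optim_swa.step()                           # increments param_groups[0]['step_counter']
optim_swa.update_swa()                         # creates state[p]['swa_buffer'] for every parameter

p = optim_swa.param_groups[0]['params'][0]
buf = optim_swa.state[p]['swa_buffer']         # running average of this parameter
distance = torch.norm(p.data - buf)            # the norm Example no. 9 compares to its running mean
print(optim_swa.param_groups[0]['step_counter'], float(distance))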
Example no. 10
def train(opt):
    model = www_model_jamo_vertical.STR(opt, device)
    print(
        'model parameters. height {}, width {}, num of fiducial {}, input channel {}, output channel {}, hidden size {},     batch max length {}'
        .format(opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel,
                opt.output_channel, opt.hidden_size, opt.batch_max_length))

    # weight initialization
    for name, param, in model.named_parameters():
        if 'localization_fc2' in name:
            print(f'Skip {name} as it is already initialized')
            continue
        try:
            if 'bias' in name:
                init.constant_(param, 0.0)
            elif 'weight' in name:
                init.kaiming_normal_(param)

        except Exception as e:
            if 'weight' in name:
                param.data.fill_(1)
            continue

    # load pretrained model
    if opt.saved_model != '':
        base_path = './models'
        print(
            f'looking for pretrained model from {os.path.join(base_path, opt.saved_model)}'
        )

        try:
            model.load_state_dict(
                torch.load(os.path.join(base_path, opt.saved_model)))
            print('loading complete ')
        except Exception as e:
            print(e)
            print('could not find model')

    #data parallel for multi GPU
    model = torch.nn.DataParallel(model).to(device)
    model.train()

    # loss
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(
        device)  #ignore [GO] token = ignore index 0
    log_avg = utils.Averager()

    # keep only the parameters that require gradients
    filtered_parameters = []
    params_num = []
    for p in filter(lambda p: p.requires_grad, model.parameters()):
        filtered_parameters.append(p)
        params_num.append(np.prod(p.size()))
    print('Trainable params : ', sum(params_num))

    # optimizer

    #     base_opt = optim.Adadelta(filtered_parameters, lr= opt.lr, rho = opt.rho, eps = opt.eps)
    base_opt = torch.optim.Adam(filtered_parameters, lr=0.001)
    optimizer = SWA(base_opt)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           verbose=True,
                                                           patience=2,
                                                           factor=0.5)
    #     optimizer = adabound.AdaBound(filtered_parameters, lr=1e-3, final_lr=0.1)

    # opt log
    with open(f'./models/{opt.experiment_name}/opt.txt', 'a') as opt_file:
        opt_log = '---------------------Options-----------------\n'
        args = vars(opt)
        for k, v in args.items():
            opt_log += f'{str(k)} : {str(v)}\n'
        opt_log += '---------------------------------------------\n'
        opt_file.write(opt_log)

    #start training
    start_time = time.time()
    best_accuracy = -1
    best_norm_ED = -1
    swa_count = 0

    for n_epoch, epoch in enumerate(range(opt.num_epoch)):
        for n_iter, data_point in enumerate(data_loader):

            image_tensors, top, mid, bot = data_point

            image = image_tensors.to(device)
            text_top, length_top = top_converter.encode(
                top, batch_max_length=opt.batch_max_length)
            text_mid, length_mid = middle_converter.encode(
                mid, batch_max_length=opt.batch_max_length)
            text_bot, length_bot = bottom_converter.encode(
                bot, batch_max_length=opt.batch_max_length)
            batch_size = image.size(0)

            pred_top, pred_mid, pred_bot = model(image, text_top[:, :-1],
                                                 text_mid[:, :-1],
                                                 text_bot[:, :-1])

            #             cost_top = criterion(pred_top.view(-1, pred_top.shape[-1]), text_top[:, 1:].contiguous().view(-1))
            #             cost_mid = criterion(pred_mid.view(-1, pred_mid.shape[-1]), text_mid[:, 1:].contiguous().view(-1))
            #             cost_bot = criterion(pred_bot.view(-1, pred_bot.shape[-1]), text_bot[:, 1:].contiguous().view(-1))
            if n_iter % 2 == 0:

                cost_top = utils.reduced_focal_loss(
                    pred_top.view(-1, pred_top.shape[-1]),
                    text_top[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
                cost_mid = utils.reduced_focal_loss(
                    pred_mid.view(-1, pred_mid.shape[-1]),
                    text_mid[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
                cost_bot = utils.reduced_focal_loss(
                    pred_bot.view(-1, pred_bot.shape[-1]),
                    text_bot[:, 1:].contiguous().view(-1),
                    ignore_index=0,
                    gamma=2,
                    alpha=0.25,
                    threshold=0.5)
            else:
                cost_top = utils.CB_loss(text_top[:, 1:].contiguous().view(-1),
                                         pred_top.view(-1, pred_top.shape[-1]),
                                         top_per_cls, opt.top_n_cls, 'focal',
                                         0.999, 0.5)
                cost_mid = utils.CB_loss(text_mid[:, 1:].contiguous().view(-1),
                                         pred_mid.view(-1, pred_mid.shape[-1]),
                                         mid_per_cls, opt.middle_n_cls,
                                         'focal', 0.999, 0.5)
                cost_bot = utils.CB_loss(text_bot[:, 1:].contiguous().view(-1),
                                         pred_bot.view(-1, pred_bot.shape[-1]),
                                         bot_per_cls, opt.bottom_n_cls,
                                         'focal', 0.999, 0.5)
            cost = cost_top * 0.33 + cost_mid * 0.33 + cost_bot * 0.33

            loss_avg = utils.Averager()
            loss_avg.add(cost)

            model.zero_grad()
            cost.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), opt.grad_clip)  #gradient clipping with 5
            optimizer.step()
            print(loss_avg.val())

            #validation
            if (n_iter % opt.valInterval == 0) & (n_iter != 0):
                elapsed_time = time.time() - start_time
                with open(f'./models/{opt.experiment_name}/log_train.txt',
                          'a') as log:
                    model.eval()
                    with torch.no_grad():
                        valid_loss, current_accuracy, current_norm_ED, pred_top_str, pred_mid_str, pred_bot_str, label_top, label_mid, label_bot, infer_time, length_of_data = evaluate.validation_jamo(
                            model, criterion, valid_loader, top_converter,
                            middle_converter, bottom_converter, opt)
                    scheduler.step(current_accuracy)
                    model.train()

                    present_time = time.localtime()
                    loss_log = f'[epoch : {n_epoch}/{opt.num_epoch}] [iter : {n_iter*opt.batch_size} / {int(len(data) * 0.95)}]\n' + f'Train loss : {loss_avg.val():0.5f}, Valid loss : {valid_loss:0.5f}, Elapsed time : {elapsed_time:0.5f}, Present time : {present_time[1]}/{present_time[2]}, {present_time[3]+9} : {present_time[4]}'
                    loss_avg.reset()

                    current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"current_norm_ED":17s}: {current_norm_ED:0.2f}'

                    #keep the best
                    if current_accuracy > best_accuracy:
                        best_accuracy = current_accuracy
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_accuracy_{round(current_accuracy,2)}.pth'
                        )

                    if current_norm_ED > best_norm_ED:
                        best_norm_ED = current_norm_ED
                        torch.save(
                            model.module.state_dict(),
                            f'./models/{opt.experiment_name}/best_norm_ED.pth')

                    best_model_log = f'{"Best accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
                    loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}'
                    print(loss_model_log)
                    log.write(loss_model_log + '\n')

                    dashed_line = '-' * 80
                    head = f'{"Ground Truth":25s} | {"Prediction" :25s}| T/F'
                    predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n'

                    random_idx = np.random.choice(range(len(label_top)),
                                                  size=5,
                                                  replace=False)
                    label_concat = np.concatenate([
                        np.asarray(label_top).reshape(1, -1),
                        np.asarray(label_mid).reshape(1, -1),
                        np.asarray(label_bot).reshape(1, -1)
                    ],
                                                  axis=0).reshape(3, -1)
                    pred_concat = np.concatenate([
                        np.asarray(pred_top_str).reshape(1, -1),
                        np.asarray(pred_mid_str).reshape(1, -1),
                        np.asarray(pred_bot_str).reshape(1, -1)
                    ],
                                                 axis=0).reshape(3, -1)

                    for i in random_idx:
                        label_sample = label_concat[:, i]
                        pred_sample = pred_concat[:, i]

                        gt_str = utils.str_combine(label_sample[0],
                                                   label_sample[1],
                                                   label_sample[2])
                        pred_str = utils.str_combine(pred_sample[0],
                                                     pred_sample[1],
                                                     pred_sample[2])
                        predicted_result_log += f'{gt_str:25s} | {pred_str:25s} | \t{str(pred_str == gt_str)}\n'
                    predicted_result_log += f'{dashed_line}'
                    print(predicted_result_log)
                    log.write(predicted_result_log + '\n')

                # Stochastic weight averaging
                optimizer.update_swa()
                swa_count += 1
                if swa_count % 5 == 0:
                    optimizer.swap_swa_sgd()
                    torch.save(
                        model.module.state_dict(),
                        f'./models/{opt.experiment_name}/swa_{swa_count}.pth')

        if (n_epoch) % 5 == 0:
            torch.save(model.module.state_dict(),
                       f'./models/{opt.experiment_name}/{n_epoch}.pth')
Example no. 11
class TPUFitter:
    
    def __init__(self, model, device, config, base_model_path='/', model_name='unnamed', model_prefix='roberta', model_version='v1', out_path='/', log_path='/'):
        self.log_path = Path(log_path, 'log').with_suffix('.txt')
        self.log('TPUFitter initialization started.', direct_out=True)
        self.config = config
        self.epoch = 0
        self.base_model_path = base_model_path
        self.model_name = model_name
        self.model_version = model_version
        self.model_path = Path(self.base_model_path, self.model_name, self.model_version)
        
        self.out_path = out_path
        self.node_path = Path(self.out_path, 'node_submissions')
        self.create_dir_structure()

        self.model = model
        self.device = device
        # whether to use stochastic weight averaging
        self.use_SWA = config.use_SWA
        # whether to use different learning rates for the backbone and the classifier head
        self.use_diff_lr = config.use_diff_lr
        
        self._set_optimizer_scheduler()
        self.criterion = config.criterion
        self.best_score = -1.0
        self.log(f'Fitter prepared. Device is {self.device}', direct_out=True)
    
    def create_dir_structure(self):
        self.node_path.mkdir(parents=True, exist_ok=True)
        self.log(f'**** Directory structure created ****', direct_out=True)
    
    def _set_optimizer_scheduler(self):
        self.log('Optimizer and scheduler initialization started.', direct_out=True)
        def is_backbone(n):
            return 'backbone' in n

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        # use different learning rate for backbone transformer and classifier head
        if self.use_diff_lr:
            backbone_lr, head_lr = self.config.lr*xm.xrt_world_size(), self.config.lr*xm.xrt_world_size()*500
            optimizer_grouped_parameters = [
                # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                {"params": [p for n, p in param_optimizer if is_backbone(n)], "lr": backbone_lr},
                {"params": [p for n, p in param_optimizer if not is_backbone(n)], "lr": head_lr}
            ]
            self.log(f'Different Learning rate for backbone: {backbone_lr} head:{head_lr}')
        else:
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                ]
        
        try:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size())
            # self.optimizer = SGD(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size(), momentum=0.9)
        except:
            param_g_1 = [p for n, p in param_optimizer if is_backbone(n)]
            param_g_2 = [p for n, p in param_optimizer if not is_backbone(n)]
            param_intersect = list(set(param_g_1) & set(param_g_2))
            self.log(f'intersect: {param_intersect}', direct_out=True)

        if self.use_SWA:
            self.optimizer = SWA(self.optimizer)
        
        if 'num_training_steps' in self.config.scheduler_params:
            num_training_steps = int(self.config.train_lenght / self.config.batch_size / xm.xrt_world_size() * self.config.n_epochs)
            self.log(f'Number of training steps: {num_training_steps}', direct_out=True)
            self.config.scheduler_params['num_training_steps'] = num_training_steps
        
        self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)

    def fit(self, train_loader, validation_loader, n_epochs=None):
        self.log(f'**** Fitting process has been started ****', direct_out=True)
        if n_epochs is None:
            n_epochs = self.config.n_epochs
        
        for e in range(n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr} \nEpoch:{e}')

            t = time.time()
            para_loader = pl.ParallelLoader(train_loader, [self.device])
            losses, final_scores = self.train_one_epoch(para_loader.per_device_loader(self.device), e)
            
            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

            t = time.time()
            para_loader = pl.ParallelLoader(validation_loader, [self.device])
            
            # swap SWA weights for validation
            if self.use_SWA:
                self.log('Swapping SWA weights for validation', direct_out=True)
                self.optimizer.swap_swa_sgd()
            
            losses, final_scores, threshold = self.validation(para_loader.per_device_loader(self.device))
            self.log(f'[RESULT]: Validation. Epoch: {self.epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, best_th: {threshold.find:.3f}, time: {(time.time() - t):.5f}')
            # swap back to normal weights to continue training
            if self.use_SWA:
                self.log('Swapping back to original weights to continue training', direct_out=True)
                self.optimizer.swap_swa_sgd()
            
            if final_scores.avg > self.best_score:
                self.best_score = final_scores.avg
                self.save('best_model')
                self.log('Best model has been updated', direct_out=True)
                # after one epoch, update SWA model if validation score is increased
                if self.use_SWA:
                    self.optimizer.update_swa()
                    self.log('SWA model weights have been updated', direct_out=True)

            if self.config.validation_scheduler:
                # self.scheduler.step(metrics=final_scores.avg)
                self.scheduler.step()
            
            self.epoch += 1
    
    def run_tuning_and_inference(self, test_loader, validation_loader, validation_tune_loader, n_epochs):
        self.log('******Validation tuning and inference has started*****', direct_out=True)
        self.run_validation_tuning(validation_loader, validation_tune_loader, n_epochs)
        para_loader = pl.ParallelLoader(test_loader, [self.device])
        self.run_inference(para_loader.per_device_loader(self.device))
    
    def run_validation_tuning(self, validation_loader, validation_tune_loader, n_epochs):
        self.log('******Validation tuning has started*****', direct_out=True)
        # self.optimizer.param_groups[0]['lr'] = self.config.lr*xm.xrt_world_size() / (epoch + 1)
        self.fit(validation_tune_loader, validation_loader, n_epochs)
    
    def validation(self, val_loader):
        self.log(f'**** Validation process has been started ****', direct_out=True)
        self.model.eval()
        losses = AverageMeter()
        final_scores = RocAucMeter()
        threshold = ThresholdMeter()

        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(val_loader):
            self.log(
                f'Valid Step {step}, loss: ' + \
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                f'time: {(time.time() - t):.5f}', step=step
            )
            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long) 
                attention_masks = attention_masks.to(self.device, dtype=torch.long) 
                targets = targets.to(self.device, dtype=torch.float) 

                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)
                
                batch_size = inputs.size(0)

                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)
                threshold.update(targets, outputs)
        
        return losses, final_scores, threshold

    def train_one_epoch(self, train_loader, epoch):
        self.log(f'**** Epoch training has started: {epoch} ****', direct_out=True)
        self.model.train()

        losses = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(train_loader):
            self.log(
                f'Train Step {step}, loss: ' + \
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                f'time: {(time.time() - t):.5f}', step=step
            )

            inputs = inputs.to(self.device, dtype=torch.long)
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            targets = targets.to(self.device, dtype=torch.float)

            self.optimizer.zero_grad()

            outputs = self.model(inputs, attention_masks)
            loss = self.criterion(outputs, targets)

            batch_size = inputs.size(0)
            
            final_scores.update(targets, outputs)
            losses.update(loss.detach().item(), batch_size)

            loss.backward()
            xm.optimizer_step(self.optimizer)

            if self.config.step_scheduler:
                self.scheduler.step()
        
        return losses, final_scores

    def run_inference(self, test_loader):
        self.log(f'**** Inference process has been started ****', direct_out=True)
        self.model.eval()
        result = {'id': [], 'toxic': []}
        
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):
            self.log(f'Prediction Step {step}, time: {(time.time() - t):.5f}', step=step)

            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long) 
                attention_masks = attention_masks.to(self.device, dtype=torch.long)
                outputs = self.model(inputs, attention_masks)
                toxics = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()[:,1]

            result['id'].extend(ids.cpu().numpy())
            result['toxic'].extend(toxics)

        result = pd.DataFrame(result)
        print(f'Node path is: {self.node_path}')
        node_count = len(list(self.node_path.glob('*.csv')))
        result.to_csv(self.node_path/f'submission_{node_count}_{datetime.utcnow().microsecond}_{random.random()}.csv', index=False)

    def run_pseudolabeling(self, test_loader, epoch):
        losses = AverageMeter()
        final_scores = RocAucMeter()

        self.model.eval()
        
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):

            inputs = inputs.to(self.device, dtype=torch.long) 
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            outputs = self.model(inputs, attention_masks)
            # print(f'Inputs: {inputs} size: {inputs.size()}')
            # print(f'outputs: {outputs} size: {outputs.size()}')
            toxics = torch.nn.functional.softmax(outputs, dim=1)[:,1]
            toxic_mask = (toxics<=0.4) | (toxics>=0.8)
            # print(attention_masks.size())
            toxics = toxics[toxic_mask]
            inputs = inputs[toxic_mask]
            attention_masks = attention_masks[toxic_mask]
            # print(f'toxics: {toxics.size()}')
            # print(f'inputs: {inputs.size()}')
            if toxics.nelement() != 0:
                targets_int = (toxics>self.config.pseudolabeling_threshold).int()
                targets = torch.stack([onehot(2, target) for target in targets_int])
                # print(targets_int)
                
                self.model.train()
                self.log(
                    f'Pseudolabeling Step {step}, loss: ' + \
                    f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                    f'time: {(time.time() - t):.5f}', step=step
                )
    
                targets = targets.to(self.device, dtype=torch.float)
    
                self.optimizer.zero_grad()
    
                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)
    
                batch_size = inputs.size(0)
                
                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)
    
                loss.backward()
                xm.optimizer_step(self.optimizer)
    
                if self.config.step_scheduler:
                    self.scheduler.step()
    
        self.log(f'[RESULT]: Pseudolabeling. Epoch: {epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

    def get_submission(self, out_dir):
        submission = pd.concat([pd.read_csv(path) for path in (out_dir/'node_submissions').glob('*.csv')]).groupby('id').mean()
        return submission
    
    def save(self, name):
        self.model_path.mkdir(parents=True, exist_ok=True)
        path = (self.model_path/name).with_suffix('.bin')
        
        if self.use_SWA:
            self.optimizer.swap_swa_sgd()

        xm.save(self.model.state_dict(), path)
        self.log(f'Model has been saved')

    def log(self, message, step=None, direct_out=False):
        if direct_out or self.config.verbose:
            if direct_out or step is None or (step is not None and step % self.config.verbose_step == 0):
                xm.master_print(message)
                with open(self.log_path, 'a+') as logger:
                    xm.master_print(f'{message}', logger)
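
The TPUFitter above swaps the SWA weights in before validating and saving. Below is a minimal sketch of the swap–save–swap-back checkpointing pattern in isolation (the toy model and file name are illustrative assumptions): swap in the SWA average, save the state dict, then swap back so training continues from the live weights.

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(4, 2)
optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))

x, y = torch.randn(8, 4), torch.randn(8, 2)
for _ in range(3):
    optimizer.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()
    optimizer.update_swa()

optimizer.swap_swa_sgd()                              # model now holds the averaged weights
torch.save(model.state_dict(), 'swa_checkpoint.bin')  # checkpoint the averaged weights
optimizer.swap_swa_sgd()                              # swap back to the live weights and keep training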
Example no. 12
class Train(object):
    """Train class.
  """
    def __init__(self, ):

        trainds = AlaskaDataIter(cfg.DATA.root_path,
                                 cfg.DATA.train_txt_path,
                                 training_flag=True)
        self.train_ds = DataLoader(trainds,
                                   cfg.TRAIN.batch_size,
                                   num_workers=cfg.TRAIN.process_num,
                                   shuffle=True)

        valds = AlaskaDataIter(cfg.DATA.root_path,
                               cfg.DATA.val_txt_path,
                               training_flag=False)
        self.val_ds = DataLoader(valds,
                                 cfg.TRAIN.batch_size,
                                 num_workers=cfg.TRAIN.process_num,
                                 shuffle=False)

        self.init_lr = cfg.TRAIN.init_lr
        self.warup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = CenterNet().to(self.device)

        self.load_weight()

        if 'Adamw' in cfg.TRAIN.opt:

            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=self.init_lr,
                eps=1.e-5,
                weight_decay=self.l2_regularization)
        else:
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.init_lr,
                momentum=0.9,
                weight_decay=self.l2_regularization)

        if cfg.TRAIN.SWA > 0:
            ##use swa
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.model = nn.DataParallel(self.model)

        self.ema = EMA(self.model, 0.999)

        self.ema.register()
        ###control vars
        self.iter_num = 0

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = CenterNetLoss().to(self.device)

    def custom_loop(self):
        """Custom training and testing loop.
    Args:
      train_dist_dataset: Training dataset created using strategy.
      test_dist_dataset: Testing dataset created using strategy.
      strategy: Distribution strategy.
    Returns:
      train_loss, train_accuracy, test_loss, test_accuracy
    """
        def train_epoch(epoch_num):

            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()
            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False
            for image, hm_target, wh_target, weights in self.train_ds:

                if epoch_num < 10:
                    ### linear warmup over the first warmup_step iterations:
                    ### scale the learning rate from 0 up to init_lr
                    if self.warmup_step > 0:
                        if self.iter_num < self.warmup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(
                                    self.warmup_step) * self.init_lr
                                lr = param_group['lr']

                            logger.info('warm up with learning rate: [%f]' %
                                        (lr))

                start = time.time()

                if cfg.TRAIN.vis:
                    for i in range(image.shape[0]):

                        img = image[i].numpy()
                        img = np.transpose(img, axes=[1, 2, 0])
                        hm = hm_target[i].numpy()
                        wh = wh_target[i].numpy()

                        if cfg.DATA.use_int8_data:
                            hm = hm[:, :, 0].astype(np.uint8)
                            wh = wh[:, :, 0]
                        else:
                            hm = hm[:, :, 0].astype(np.float32)
                            wh = wh[:, :, 0].astype(np.float32)

                        cv2.namedWindow('s_hm', 0)
                        cv2.imshow('s_hm', hm)
                        cv2.namedWindow('s_wh', 0)
                        cv2.imshow('s_wh', wh + 1)
                        cv2.namedWindow('img', 0)
                        cv2.imshow('img', img)
                        cv2.waitKey(0)
                else:
                    data = image.to(self.device).float()

                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(
                            self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()
                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()

                    batch_size = data.shape[0]

                    cls, wh = self.model(data)

                    cls_loss, wh_loss = self.criterion(
                        [cls, wh], [hm_target, wh_target, weights])

                    current_loss = cls_loss + wh_loss
                    summary_loss_cls.update(cls_loss.detach().item(),
                                            batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)
                    self.optimizer.zero_grad()

                    if cfg.TRAIN.mix_precision:
                        with amp.scale_loss(current_loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        current_loss.backward()

                    self.optimizer.step()
                    if cfg.TRAIN.ema:
                        self.ema.update()
                    self.iter_num += 1
                    time_cost_per_batch = time.time() - start

                    images_per_sec = cfg.TRAIN.batch_size * cfg.TRAIN.num_gpu / time_cost_per_batch

                    if self.iter_num % cfg.TRAIN.log_interval == 0:

                        log_message = '[TRAIN], '\
                                      'Epoch %d Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'cls_loss: %.6f, '\
                                      'wh_loss: %.6f, ' \
                                      'time: %.6f, '\
                                      'speed %d images/sec' % (
                                          epoch_num,
                                          self.iter_num,
                                          summary_loss_cls.avg+summary_loss_wh.avg,
                                          summary_loss_cls.avg ,
                                          summary_loss_wh.avg,
                                          time.time() - start,
                                          images_per_sec)
                        logger.info(log_message)

                if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                    self.optimizer.update_swa()

            return summary_loss_cls, summary_loss_wh

        def test_epoch(epoch_num):
            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()

            self.model.eval()
            t = time.time()
            with torch.no_grad():
                for step, (image, hm_target, wh_target,
                           weights) in enumerate(self.val_ds):

                    data = image.to(self.device).float()

                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(
                            self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()

                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()
                    batch_size = data.shape[0]

                    # the forward pass already runs under the outer torch.no_grad() context
                    cls, wh = self.model(data)

                    cls_loss, wh_loss = self.criterion(
                        [cls, wh], [hm_target, wh_target, weights])

                    summary_loss_cls.update(cls_loss.detach().item(),
                                            batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)

                    if step % cfg.TRAIN.log_interval == 0:

                        log_message =   '[VAL], '\
                                        'Epoch %d Step %d, ' \
                                        'summary_loss: %.6f, ' \
                                        'cls_loss: %.6f, '\
                                        'wh_loss: %.6f, ' \
                                        'time: %.6f' % (epoch_num,
                                                        step,
                                                        summary_loss_cls.avg+summary_loss_wh.avg,
                                                        summary_loss_cls.avg,
                                                        summary_loss_wh.avg,
                                                        time.time() - t)

                        logger.info(log_message)

            return summary_loss_cls, summary_loss_wh

        for epoch in range(self.epochs):

            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss_cls, summary_loss_wh = train_epoch(epoch)

            train_epoch_log_message = '[centernet], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' cls_loss: %.6f, ' \
                                      ' wh_loss: %.6f, ' \
                                      ' time:%.5f' % (epoch,
                                                      summary_loss_cls.avg+summary_loss_wh.avg,
                                                      summary_loss_cls.avg,
                                                      summary_loss_wh.avg,
                                                      (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:

                ### swap in the averaged (SWA) weights for evaluation and saving
                self.optimizer.swap_swa_sgd()

            ## switch to the EMA shadow weights for evaluation and saving
            if cfg.TRAIN.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:

                summary_loss_cls, summary_loss_wh = test_epoch(epoch)

                val_epoch_log_message = '[centernet], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' cls_loss: %.6f, ' \
                                        ' wh_loss: %.6f, ' \
                                        ' time:%.5f' % (epoch,
                                                        summary_loss_cls.avg+summary_loss_wh.avg,
                                                        summary_loss_cls.avg,
                                                        summary_loss_wh.avg,
                                                        (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)

            #### save the model at the end of every epoch into cfg.MODEL.model_path
            current_model_saved_name = os.path.join(
                cfg.MODEL.model_path, 'centernet_epoch_%d_val_loss%.6f.pth' %
                (epoch, summary_loss_cls.avg + summary_loss_wh.avg))

            logger.info('saving model to %s' % current_model_saved_name)
            torch.save(self.model.module.state_dict(),
                       current_model_saved_name)

            ####switch back
            if cfg.TRAIN.ema:
                self.ema.restore()

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:
                ### swap back to the plain weights to train the next epoch
                ### (condition mirrors the swap-in above so the swaps stay paired)
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
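
A minimal driver sketch is shown below; it assumes the cfg object and the Train class come from this example's own modules, and the __main__ guard is illustrative rather than part of the original code.

if __name__ == '__main__':
    # builds datasets, model, optimizer and scheduler (see Train.__init__ above)
    trainer = Train()
    # runs training and validation for cfg.TRAIN.epoch epochs
    trainer.custom_loop()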