Example #1
class Train(object):
    """Train class.
  """
    def __init__(self, train_ds, val_ds, fold):
        self.fold = fold

        self.init_lr = cfg.TRAIN.init_lr
        self.warmup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = Net().to(self.device)

        self.load_weight()

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            cfg.TRAIN.weight_decay_factor
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if 'Adamw' in cfg.TRAIN.opt:
            # use the parameter groups built above so bias/LayerNorm are excluded from weight decay
            self.optimizer = torch.optim.AdamW(optimizer_grouped_parameters,
                                               lr=self.init_lr,
                                               eps=1.e-5)
        else:
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=0.001,
                                             momentum=0.9)

        if cfg.TRAIN.SWA > 0:
            ## use SWA (manual mode: update_swa()/swap_swa_sgd() are called explicitly below)
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.ema = EMA(self.model, 0.999)

        self.ema.register()
        ###control vars
        self.iter_num = 0

        self.train_ds = train_ds

        self.val_ds = val_ds

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = nn.BCEWithLogitsLoss().to(self.device)

    def custom_loop(self):
        """Custom training and validation loop.

        Runs `self.epochs` epochs of training, validates every
        `cfg.TRAIN.test_interval` epochs, and saves a checkpoint at the
        end of each epoch.
        """
        def distributed_train_epoch(epoch_num):

            summary_loss = AverageMeter()
            acc_score = ACCMeter()
            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False
            for step in range(self.train_ds.size):

                if epoch_num < 10:
                    ### execute warm-up over the first iterations (gated by warmup_step)
                    if self.warmup_step > 0:
                        if self.iter_num < self.warmup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(
                                    self.warmup_step) * self.init_lr
                                lr = param_group['lr']

                            logger.info('warm up with learning rate: [%f]' %
                                        (lr))

                start = time.time()

                images, data, target = self.train_ds()
                images = torch.from_numpy(images).to(self.device).float()
                data = torch.from_numpy(data).to(self.device).float()
                target = torch.from_numpy(target).to(self.device).float()

                batch_size = data.shape[0]

                output = self.model(images, data)

                current_loss = self.criterion(output, target)

                summary_loss.update(current_loss.detach().item(), batch_size)
                acc_score.update(target, output)
                self.optimizer.zero_grad()

                if cfg.TRAIN.mix_precision:
                    with amp.scale_loss(current_loss,
                                        self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    current_loss.backward()

                self.optimizer.step()
                if cfg.MODEL.ema:
                    self.ema.update()
                self.iter_num += 1
                time_cost_per_batch = time.time() - start

                images_per_sec = cfg.TRAIN.batch_size / time_cost_per_batch

                if self.iter_num % cfg.TRAIN.log_interval == 0:

                    log_message = '[fold %d], '\
                                  'Train Step %d, ' \
                                  'summary_loss: %.6f, ' \
                                  'accuracy: %.6f, ' \
                                  'time: %.6f, '\
                                  'speed: %d images/sec' % (
                                      self.fold,
                                      step,
                                      summary_loss.avg,
                                      acc_score.avg,
                                      time.time() - start,
                                      images_per_sec)
                    logger.info(log_message)

            if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                self.optimizer.update_swa()

            return summary_loss, acc_score

        def distributed_test_epoch(epoch_num):
            summary_loss = AverageMeter()
            acc_score = ACCMeter()

            self.model.eval()
            t = time.time()
            with torch.no_grad():
                for step in range(self.val_ds.size):
                    images, data, target = self.val_ds()
                    images = torch.from_numpy(images).to(self.device).float()
                    data = torch.from_numpy(data).to(self.device).float()
                    target = torch.from_numpy(target).to(self.device).float()
                    batch_size = data.shape[0]

                    output = self.model(images, data)
                    loss = self.criterion(output, target)

                    summary_loss.update(loss.detach().item(), batch_size)
                    acc_score.update(target, output)

                    if step % cfg.TRAIN.log_interval == 0:

                        log_message = '[fold %d], '\
                                      'Val Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'acc: %.6f, ' \
                                      'time: %.6f' % (
                                      self.fold,step, summary_loss.avg, acc_score.avg, time.time() - t)

                        logger.info(log_message)

            return summary_loss, acc_score

        for epoch in range(self.epochs):

            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss, acc_score = distributed_train_epoch(epoch)

            train_epoch_log_message = '[fold %d], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' accuracy: %.5f,' \
                                      ' time:%.5f' % (
                                      self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:

                ###switch to avg model
                self.optimizer.swap_swa_sgd()

            ## switch to EMA weights
            if cfg.MODEL.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:

                summary_loss, acc_score = distributed_test_epoch(epoch)

                val_epoch_log_message = '[fold %d], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' accuracy: %.5f,' \
                                        ' time:%.5f' % (
                                         self.fold,epoch, summary_loss.avg,acc_score.avg, (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)
            ###save the best auc model

            #### save the model every end of epoch
            current_model_saved_name = './models/fold%d_epoch_%d_val_loss%.6f.pth' % (
                self.fold, epoch, summary_loss.avg)

            logger.info('Model saved to %s' % current_model_saved_name)
            torch.save(self.model.state_dict(), current_model_saved_name)

            ####switch back
            if cfg.MODEL.ema:
                self.ema.restore()

            # save_checkpoint({
            #           'state_dict': self.model.state_dict(),
            #           },iters=epoch,tag=current_model_saved_name)

            if cfg.TRAIN.SWA > 0 and epoch > cfg.TRAIN.SWA:
                ###switch back to plain model to train next epoch
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
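
Example 1 drives SWA manually: the wrapper is constructed without swa_start/swa_freq/swa_lr, update_swa() is called explicitly once epoch_num reaches cfg.TRAIN.SWA, and swap_swa_sgd() swaps the averaged weights in for evaluation/saving and back out before the next training epoch. A minimal, self-contained sketch of this manual torchcontrib pattern (the toy model and data below are placeholders, not taken from the example):

import torch
from torch import nn
from torchcontrib.optim import SWA

model = nn.Linear(10, 1)
base_opt = torch.optim.SGD(model.parameters(), lr=0.01)
opt = SWA(base_opt)  # manual mode: no swa_start/swa_freq/swa_lr

for epoch in range(20):
    x, y = torch.randn(32, 10), torch.randn(32, 1)  # placeholder batch
    opt.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    opt.step()
    if epoch >= 10:       # start averaging after a burn-in period
        opt.update_swa()  # fold the current weights into the running average

opt.swap_swa_sgd()        # load the averaged weights into the model for evaluation
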
Example #2
class Trainer(object):
    def __init__(self,
                 args,
                 train_dataloader=None,
                 validate_dataloader=None,
                 test_dataloader=None):
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader

        self.label_lst = [i for i in range(self.args.num_classes)]
        self.num_labels = self.args.num_classes

        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification

        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label
                      for i, label in enumerate(self.label_lst)},
            label2id={label: i
                      for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(
            self.args.bert_model_name, config=self.config)
        self.optimizer = None
        self.scheduler = None

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available(
        ) and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        train_dataloader = self.train_dataloader

        t_total = len(train_dataloader) * self.args.num_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in self.model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }, {
            'params': [
                p for n, p in self.model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters,
                             lr=self.args.lr,
                             eps=1e-8)
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
            # expose the wrapped optimizer's attributes so the LR scheduler can drive the SWA wrapper
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults

        else:
            self.optimizer = optimizer = AdamW(optimizer_grouped_parameters,
                                               lr=self.args.lr,
                                               eps=1e-8)
        self.scheduler = scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))
        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d",
                    len(self.train_dataloader) * self.args.batch_size)
        logger.info("  Num Epochs = %d", self.args.num_epochs)
        logger.info("  Total train batch size = %d", self.args.batch_size)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()

        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")

        fin_result = None
        f1_max = 0.0
        self.model.train()

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):

                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }

                # outputs = self.model(**inputs)
                # loss = outputs[0]

                # # Custom Loss
                loss, logits = self.model(**inputs)
                logits = torch.sigmoid(logits)

                labels = torch.zeros(
                    (len(batch[3]), self.num_labels)).to(self.device)
                labels[range(len(batch[3])), batch[3]] = 1

                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule

                self.model.zero_grad()
                self.optimizer.zero_grad()

                tr_loss += loss.item()
                global_step += 1
                logger.info('train loss %f', loss.item())

            logger.info('total train loss %f', tr_loss / global_step)
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            f1_max = max(fin_result['f1_macro'], f1_max)

        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()
        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a",
                  encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(
                alpha, gamma, f1_max))
        return f1_max

    def evaluate(self, mode='test'):
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only dev and test dataset available")

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d",
                    len(dataloader) * self.args.batch_size)
        logger.info("  Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        self.model.eval()

        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs['labels'].detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}

        preds = np.argmax(preds, axis=1)
        result = compute_metrics(preds, out_label_ids)
        results.update(result)

        p_macro, r_macro, f_macro, support_macro \
            = precision_recall_fscore_support(y_true=out_label_ids, y_pred=preds,
                                              labels=[i for i in range(self.num_labels)], average='macro')

        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })

        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))

        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("  %s = %s", key, str(results[key]))

        return results

    def save_model(self, num=0):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(
            state,
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):

        state = torch.load(
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, model_name))
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
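
The FocalLoss used in Example 2 is not shown in the snippet. Since it is applied to sigmoid probabilities against one-hot label vectors, a plausible minimal implementation is sketched below (an assumption about the missing class, not the author's code):

import torch
from torch import nn

class FocalLoss(nn.Module):
    # binary focal loss on probabilities: alpha weights positives, gamma down-weights easy examples
    def __init__(self, alpha=0.25, gamma=2.0, eps=1e-8):
        super().__init__()
        self.alpha, self.gamma, self.eps = alpha, gamma, eps

    def forward(self, probs, targets):
        probs = probs.clamp(self.eps, 1.0 - self.eps)
        pos = -self.alpha * (1.0 - probs) ** self.gamma * torch.log(probs)
        neg = -(1.0 - self.alpha) * probs ** self.gamma * torch.log(1.0 - probs)
        return (targets * pos + (1.0 - targets) * neg).mean()
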
Example #3
def main(args):
    # Parameters for toy data and experiments
    plt.figure(figsize=(6.4, 3.4))
    rng_seed = 42
    np.random.seed(rng_seed)

    wstar = torch.tensor([[0.973, 1.144]], dtype=torch.float)
    xs = torch.tensor(np.random.randn(50, 2), dtype=torch.float)
    labels = torch.mm(xs, wstar.T)

    p = torch.tensor(np.random.uniform(0.05, 1.0, xs.shape[0]),
                     dtype=torch.float)
    ips = 1.0 / p
    n_iters = 50
    plot_every = n_iters // 10
    arrow_width = 0.012
    legends = {}

    # The loss function we want to optimize
    def loss_fn(out, y, mult):
        l2loss = (out - y)**2.0
        logl2loss = torch.log(1.0 + (out - y)**2.0)
        return torch.mean(mult * l2loss)

    # IPS-weighted approach
    for color_index, lr in enumerate([0.01, 0.02, 0.03, 0.05,
                                      0.1]):  # 0.01, 0.03, 0.05, 0.1, 0.3
        color = "C%d" % (color_index + 2) if color_index > 0 else "C1"
        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
        with torch.no_grad():
            model.bias.zero_()
            model.weight.zero_()
        old_weights = np.copy(model.weight.data.numpy())
        np.random.seed(rng_seed + color_index + 1)
        for t in range(n_iters):
            i = np.random.randint(xs.shape[0])
            x = xs[i, :]
            y = labels[i]
            optimizer.zero_grad()
            o = model(x)
            l = loss_fn(o, y, ips[i])
            l.backward()
            optimizer.step()
            if t % plot_every == 0:
                optimizer.swap_swa_sgd()
                x, y = model.weight.data.numpy()[0]
                optimizer.swap_swa_sgd()
                ox, oy = old_weights[0]
                label = f"IPS-SGD ($\\eta={lr}$)"
                arr = plt.arrow(ox,
                                oy,
                                x - ox,
                                y - oy,
                                width=arrow_width,
                                length_includes_head=True,
                                color=color,
                                label=label)
                optimizer.swap_swa_sgd()
                old_weights = np.copy(model.weight.data.numpy())
                optimizer.swap_swa_sgd()
                legends[label] = arr

    # Sample based approach
    for lr in [10.0]:
        # lr = 3.0 # 1.0
        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
        with torch.no_grad():
            model.bias.zero_()
            model.weight.zero_()
        old_weights = np.copy(model.weight.data.numpy())
        sample_probs = np.array(ips / torch.sum(ips))
        Mbar = float(np.mean(sample_probs))
        np.random.seed(rng_seed - 1)
        for t in range(n_iters):
            i = np.argwhere(np.random.multinomial(1, sample_probs) == 1.0)[0,
                                                                           0]
            x = xs[i, :]
            y = labels[i]
            optimizer.zero_grad()
            o = model(x)
            l = loss_fn(o, y, Mbar)
            l.backward()
            optimizer.step()
            if t % plot_every == 0:
                optimizer.swap_swa_sgd()
                x, y = model.weight.data.numpy()[0]
                optimizer.swap_swa_sgd()
                ox, oy = old_weights[0]
                label = f"\\textsc{{CounterSample}} ($\\eta={lr}$)"
                arr = plt.arrow(ox,
                                oy,
                                x - ox,
                                y - oy,
                                width=arrow_width,
                                length_includes_head=True,
                                color="C2",
                                label=label)

                optimizer.swap_swa_sgd()
                old_weights = np.copy(model.weight.data.numpy())
                optimizer.swap_swa_sgd()
                legends[label] = arr

    # True IPS-weighted loss over all datapoints, used for plotting contour
    def f(x1, x2):
        w = torch.tensor([[x1], [x2]], dtype=torch.float)
        o = torch.mm(xs, w)
        return float(loss_fn(o, torch.mm(xs, wstar.reshape((2, 1))), ips))

    # Compute all useful combinations of weights and compute true loss for each one
    # This will be used to compute a contour plot
    true_x1 = np.linspace(float(wstar[0, 0]) - 1.5,
                          float(wstar[0, 0]) + 0.8)  # - 1.5 / + 1.0
    true_x2 = np.linspace(float(wstar[0, 1]) - 1.5,
                          float(wstar[0, 1]) + 1.2)  # - 1.5 / + 1.0
    true_x1, true_x2 = np.meshgrid(true_x1, true_x2)
    true_y = np.array(
        [[f(true_x1[i1, i2], true_x2[i1, i2]) for i2 in range(len(true_x2))]
         for i1 in range(len(true_x1))])

    # Contour plot with optimum
    plt.plot(wstar[0, 0], wstar[0, 1], marker='o', markersize=3, color="black")
    plt.contour(true_x1, true_x2, true_y, 10, colors="black", alpha=0.35)

    # Generate legends from arrows and make figure
    def make_legend_arrow(legend, orig_handle, xdescent, ydescent, width,
                          height, fontsize):
        p = mpatches.FancyArrow(0,
                                0.5 * height,
                                width,
                                0,
                                length_includes_head=True,
                                head_width=0.75 * height)
        return p

    def sort_op(key):
        if "CounterSample" in key:
            return ""
        else:
            return key

    labels = [key for key in sorted(legends.keys(), key=sort_op)]
    arrows = [legends[key] for key in sorted(legends.keys(), key=sort_op)]
    plt.legend(arrows,
               labels,
               ncol=2,
               loc='upper center',
               framealpha=0.95,
               handler_map={
                   mpatches.FancyArrow:
                   HandlerPatch(patch_func=make_legend_arrow)
               },
               bbox_to_anchor=(0.5, 1.0 + 0.03))
    plt.xlabel("$w_1$")
    plt.ylabel("$w_2$")
    plt.tight_layout()
    plt.savefig(args.out, format=args.format)
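
In Examples 3 and 9 the wrapper is created as SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr), so every SGD iterate is folded into the average; the shadow weights read out via swap_swa_sgd() are (essentially) the running mean $\bar{w}_t = \frac{1}{t}\sum_{i=1}^{t} w_i$ of the visited iterates, which is why the plotted trajectories trace the averaged path rather than the raw SGD iterates.
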
Example #4
    # criterion = CTCLoss()
    #optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
    scheduler = SWA(optimizer, swa_start=10, swa_freq=5, swa_lr=0.0005)  # SWA wrapper around the SGD optimizer (named `scheduler` here)
    #scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1, last_epoch=-1)
    best_loss = 1000.0
    best_acc = 0.63

    # whether to use the GPU
    torch.cuda.current_device()
    use_cuda = True
    if use_cuda:
        model = model.cuda()
    for epoch in range(50):
        #scheduler.step()
        scheduler.zero_grad()
        train_loss = train(train_loader, model, criterion, optimizer, epoch)
        val_loss = validate(val_loader, model, criterion)
        scheduler.step()

        val_label = [
            ''.join(map(str, x)) for x in val_loader.dataset.img_label
        ]
        val_predict_label = predict(val_loader, model, 3)
        val_predict_label = np.vstack([
            val_predict_label[:, :11].argmax(1),
            val_predict_label[:, 11:22].argmax(1),
            val_predict_label[:, 22:33].argmax(1),
            val_predict_label[:, 33:44].argmax(1),
            val_predict_label[:, 44:55].argmax(1),
        ]).T
Example #5
    def train(self, train_inputs):
        config = self.config.fitting
        model = train_inputs['model']
        train_data = train_inputs['train_data']
        dev_data = train_inputs['dev_data']
        epoch_start = train_inputs['epoch_start']

        train_steps = int((len(train_data) + config.batch_size - 1) / config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True)
        params_lr = []
        for key, value in model.get_params().items():
            if key in config.lr:
                params_lr.append({"params": value, 'lr': config.lr[key]})
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)

        early_stopping = EarlyStopping(model, ROOT_WEIGHT, mode='max', patience=3)
        learning_schedual = LearningSchedual(optimizer, config.epochs, config.end_epoch, train_steps, config.lr)

        aux = ModelAux(self.config, train_steps)
        moving_log = MovingData(window=100)

        ending_flag = False
        detach_flag = False
        swa_flag = False
        fgm = FGM(model)
        for epoch in range(epoch_start, config.epochs):
            for step, (inputs, targets, others) in enumerate(train_dataloader):
                inputs = dict([(key, value[0].cuda() if value[1] else value[0]) for key, value in inputs.items()])
                targets = dict([(key, value.cuda()) for key, value in targets.items()])
                if epoch > 0 and step == 0:
                    model.detach_ptm(False)
                    detach_flag = False
                if epoch == 0 and step == 0:
                    model.detach_ptm(True)
                    detach_flag = True
                # train ================================================================================================
                preds = model(inputs, en_decode=config.verbose)
                loss = model.cal_loss(preds, targets, inputs['mask'])
                loss['back'].backward()

                # adversarial training (FGM)
                if (not detach_flag) and config.en_fgm:
                    fgm.attack(emb_name='word_embeddings')  # add an adversarial perturbation to the embeddings
                    preds_adv = model(inputs, en_decode=False)
                    loss_adv = model.cal_loss(preds_adv, targets, inputs['mask'])
                    loss_adv['back'].backward()  # backprop, accumulating the adversarial gradients on top of the normal ones
                    fgm.restore(emb_name='word_embeddings')  # restore the original embedding parameters

                # torch.nn.utils.clip_grad_norm(model.parameters(), 1)
                optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():
                    logs = {}
                    if config.verbose:
                        pred_entity_point = model.find_entity(preds['pred'], others['raw_text'])
                        cn, pn, tn = self.calculate_f1(pred_entity_point, others['raw_entity'])
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1,
                                        'correct_num': cn, 'pred_num': pn,
                                        'true_num': tn}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                        logs['precise'], logs['recall'], logs['f1'] = calculate_f1(moving_data['correct_num'],
                                                                                   moving_data['pred_num'],
                                                                                   moving_data['true_num'],
                                                                                   verbose=True)
                    else:
                        metrics_data = {'loss': loss['show'].cpu().numpy(), 'sampled_num': 1}
                        moving_data = moving_log(epoch * train_steps + step, metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data['sampled_num']
                    # update lr
                    lr_data = learning_schedual.update_lr(epoch, step)
                    logs.update(lr_data)

                    if step + 1 == train_steps:
                        model.eval()
                        aux.new_line()

                        # dev ==========================================================================================

                        eval_inputs = {'model': model,
                                       'data': dev_data,
                                       'type_data': 'dev',
                                       'outfile': train_inputs['dev_res_file']}
                        dev_result = self.eval(eval_inputs)
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']
                        if logs['dev_f1'] > 0.80:
                            torch.save(model.state_dict(),
                                       "{}/auto_save_{:.6f}.ckpt".format(ROOT_WEIGHT, logs['dev_f1']))
                        if (epoch > 3 or swa_flag) and config.en_swa:
                            optimizer.update_swa()
                            swa_flag = True
                        early_stop, best_score = early_stopping(logs['dev_f1'])

                        # test =========================================================================================
                        if (epoch + 1 == config.epochs and step + 1 == train_steps) or early_stop:
                            ending_flag = True
                            if swa_flag:
                                optimizer.swap_swa_sgd()
                                optimizer.bn_update(train_dataloader, model)

                        model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return best_score
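
The FGM helper used in Example 5 (fgm.attack / fgm.restore around a second backward pass) is not defined in the snippet; a common implementation of this embedding-perturbation recipe, shown here as a sketch rather than the author's exact code, is:

import torch

class FGM:
    # Fast Gradient Method: perturb the embedding weights along their gradient, then restore them
    def __init__(self, model, epsilon=1.0):
        self.model = model
        self.epsilon = epsilon
        self.backup = {}

    def attack(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and param.grad is not None:
                self.backup[name] = param.data.clone()
                norm = torch.norm(param.grad)
                if norm != 0 and not torch.isnan(norm):
                    param.data.add_(self.epsilon * param.grad / norm)

    def restore(self, emb_name='word_embeddings'):
        for name, param in self.model.named_parameters():
            if param.requires_grad and emb_name in name and name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
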
Example #6
class Trainer:
    def __init__(
            self,
            net: ModuloNet,
            metrics=["cohen_kappa", "f1", "accuracy"],
            epochs=30,
            metric_to_maximize="accuracy",
            patience=None,
            batch_size=32,
            save_folder=None,
            loss=None,
            regularization=None,
            swa=None,
            optimizer=None,
            num_workers=0,
            net_methods=None
    ):
        if optimizer is None:
            optimizer = {'type': 'adam', 'args': {'lr': 1e-3}}
        if loss is None:
            loss = {'type': 'cross_entropy', 'args': {}}

        self.net_methods = net_methods if net_methods is not None else []
        print('METHODS')
        print(self.net_methods)
        self.net = net
        print('####################')
        print("Device: ", net.device)
        print('Using:', num_workers, ' workers')
        print('Trainable params', sum(p.numel() for p in net.parameters() if p.requires_grad))
        print('Total params', sum(p.numel() for p in net.parameters()))

        print('####################')

        self.loss_function = loss_functions[loss['type']](**loss['args'])
        self.optimizer_params = optimizer
        self.swa_params = swa
        self.regularization = []
        if regularization is not None:
            for regularizer in regularization:
                self.regularization += [
                    regularizers[regularizer['type']](self.net, **regularizer['args'])]

        self.reset_optimizer()

        self.metrics = {
            score: score_function for score, score_function in score_functions.items()
            if score in metrics + [metric_to_maximize]
        }

        self.iterations = 0
        self.epochs = epochs
        self.metric_to_maximize = metric_to_maximize
        self.patience = patience if patience else epochs
        self.loss_values = []
        self.save_folder = save_folder
        self.batch_size = batch_size
        self.num_workers = num_workers

    def reset_optimizer(self):
        self.base_optimizer = optimizers[self.optimizer_params['type']](self.net.parameters(),
                                                                        **self.optimizer_params[
                                                                            'args'])

        if self.swa_params is not None:
            self.optimizer = SWA(self.base_optimizer, **self.swa_params)
            self.swa = True
            self.averaged_weights = False
        else:
            self.optimizer = self.base_optimizer
            self.swa = False

    def on_batch_start(self):
        pass

    def on_epoch_end(self):
        pass

    def validate(self, validation_dataset, return_metrics_per_records=False, verbose=False):
        self.net.eval()
        if self.swa:
            self.optimizer.swap_swa_sgd()
            self.averaged_weights = not self.averaged_weights

        metrics_epoch = {
            metric: []
            for metric in self.metrics.keys()
        }
        metrics_per_records = {}
        hypnograms = {}
        predictions = self.net.predict_on_dataset(validation_dataset, return_prob=False,
                                                  verbose=verbose)
        record_weights = []

        for record in validation_dataset.records:
            metrics_per_records[record] = {}
            hypnogram_target = validation_dataset.hypnogram[record]
            hypnogram_predicted = predictions[record]

            hypnograms[os.path.split(record)[-2]] = {
                'predicted': hypnogram_predicted.astype(int).tolist(),
                'target': hypnogram_target.astype(int).tolist()}
            record_weights += [np.sum(hypnogram_target >= 0)]
            for metric, metric_function in self.metrics.items():
                metric_value = metric_function(hypnogram_target, hypnogram_predicted)
                metrics_per_records[record][metric] = metric_value
                metrics_epoch[metric].append(metric_value)

        record_weights = np.array(record_weights)
        for metric in metrics_epoch.keys():
            metrics_epoch[metric] = np.array(metrics_epoch[metric])
            record_weights_tp = record_weights[~np.isnan(metrics_epoch[metric])]
            metrics_epoch[metric] = metrics_epoch[metric][~np.isnan(metrics_epoch[metric])]

            try:
                metrics_epoch[metric] = np.average(metrics_epoch[metric], weights=record_weights_tp)
            except ZeroDivisionError:
                metrics_epoch[metric] = np.nan

            if self.metric_to_maximize == metric:
                value = metrics_epoch[metric]

        if self.swa:
            if self.averaged_weights:
                self.optimizer.swap_swa_sgd()
                self.averaged_weights = not self.averaged_weights

        if return_metrics_per_records:

            return metrics_epoch, value, metrics_per_records, hypnograms
        else:
            return metrics_epoch, value

    def train_on_batch(self, data, mask=-1):
        # 1. train network
        # Set network in train mode
        self.net.train()

        # Retrieve inputs
        args, hypnogram = self.net.get_args(data)
        device = self.net.device
        hypnogram = hypnogram.to(device)
        mask = [hypnogram != mask]
        hypnogram = hypnogram[mask]

        # zero the network parameter gradients
        self.optimizer.zero_grad()

        # forward + backward
        output = self.net.forward(*args)[0]
        output = output[mask]
        loss_train = self.loss_function(output, hypnogram)
        if self.regularization is not None:
            for regularizer in self.regularization:
                regularizer.regularized_all_param(loss_train)

        loss_train.backward()
        if isinstance(hypnogram, tuple):
            hypnogram = hypnogram[0]

        self.iterations += 1
        return output, loss_train, hypnogram

    def validate_on_batch(self, data, mask=-1):
        # 2. Evaluate network on validation
        # Set network in eval mode
        self.net.eval()
        # Retrieve inputs
        args, hypnogram = self.net.get_args(data)
        device = self.net.device
        hypnogram = hypnogram.to(device)
        mask = [hypnogram != mask]
        hypnogram = hypnogram[mask]
        # forward
        output = self.net.forward(*args)[0]
        output = output[mask]
        loss_validation = self.loss_function(output, hypnogram)
        if isinstance(hypnogram, tuple):
            hypnogram = hypnogram[0]

        return output, loss_validation, hypnogram

    def train(self, train_dataset, validation_dataset, verbose=1, reset_optimizer=True):
        """
        for epoch:
            for batch on train set:
                train net with optimizer SGD
                eval on a random batch of validation set
                print metrics on train set and val set every 1% of dataset
            Evaluate metrics BY RECORD, take mean
            if metric_to_maximize value > best_value:
                store best_net*
            else:
                patience += 1
            if patience too big:
                return
        """
        if reset_optimizer:
            self.reset_optimizer()

        if self.save_folder:
            self.save_weights('best_net')

        dataloader_train = DataLoader(
            train_dataset, shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers,
            pin_memory=True
        )

        metrics_final = {
            metric: 0
            for metric in self.metrics.keys()
        }

        best_value = 0
        counter_patience = 0
        for epoch in range(0, self.epochs):
            if verbose == 0:
                print('EPOCH:', epoch)
            # init running_loss
            running_loss_train_epoch = 0
            running_metrics = {metric: 0 for metric in self.metrics.keys()}
            buffer_outputs_train = ([], [])
            if verbose > 0:
                # Configure the progress bar
                t = tqdm(dataloader_train, 0)
                t.set_description("EPOCH {}".format(epoch))
                update_postfix_every = max(int(len(t) * 0.05), 1)
                counter_update_postfix = 0
            else:
                t = dataloader_train

            t_start_train = time.time()
            for i, data in enumerate(t):
                self.on_batch_start()

                if verbose > 0:
                    if (i + 1) % update_postfix_every == 0 and i != 0:
                        # compute desired metrics each update_postfix_every
                        for metric_name, metric_function in self.metrics.items():
                            running_metrics[metric_name] += metric_function(
                                buffer_outputs_train[0], buffer_outputs_train[1]
                            )
                        buffer_outputs_train = ([], [])
                        counter_update_postfix += 1
                        t.set_postfix(
                            loss=running_loss_train_epoch / (i + 1),
                            **{
                                k: v / counter_update_postfix
                                for k, v in running_metrics.items()
                            }
                        )
                        self.loss_values.append((running_loss_train_epoch, i + 1))

                # train
                output, loss_train, hypnogram = self.train_on_batch(data)

                # fill metrics for print
                running_loss_train_epoch += loss_train.item()
                buffer_outputs_train[0].extend(list(output.max(1)[1].cpu().numpy()))
                buffer_outputs_train[1].extend(list(hypnogram.cpu().numpy().flatten()))

                # gradient descent
                self.optimizer.step()
            t_stop_train = time.time()

            t_start_validation = time.time()
            metrics_epoch, value = self.validate(validation_dataset=validation_dataset)
            t_stop_validation = time.time()

            metrics_epoch["training_duration"] = t_stop_train - t_start_train
            metrics_epoch["validation_duration"] = t_stop_validation - t_start_validation

            if self.save_folder:
                self.save_weights(str(epoch) + "_net")

                json.dump(metrics_epoch,
                          open(self.save_folder + str(epoch) + "_metrics_epoch.json", "w"))

            if value > best_value:
                print("New best {} !".format(self.metric_to_maximize), value)
                # best_net = copy.deepcopy(self.net)
                metrics_final = {
                    metric: metrics_epoch[metric]
                    for metric in self.metrics.keys()
                }
                best_value = value
                counter_patience = 0
                if self.save_folder:
                    self.save_weights('best_net')
                    json.dump(metrics_epoch,
                              open(self.save_folder + "metrics_best_epoch.json", "w"))
            else:
                counter_patience += 1

            if counter_patience > self.patience:
                break

            self.on_epoch_end()
        return metrics_final

    def save_weights(self, file_name):
        self.net.save(self.save_folder + file_name)
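
One caveat about Example 6: when validate() swaps in the SWA weights, any BatchNorm running statistics still correspond to the most recent SGD weights. torchcontrib's wrapper provides bn_update() to recompute them (Example 5 calls it before its final swap); inside validate() it could be added roughly as follows, assuming a training DataLoader is available and yields batches the network can consume directly (train_dataloader is a placeholder name):

        if self.swa:
            self.optimizer.swap_swa_sgd()                       # load the averaged weights
            self.optimizer.bn_update(train_dataloader, self.net)  # recompute BatchNorm statistics
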
Example #7
class Optimizer:
    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self,
                 gradient_clipping,
                 swa_start=None,
                 swa_freq=None,
                 swa_lr=None,
                 **kwargs):
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr

    def set_parameters(self, parameters):
        self.parameters = tuple(parameters)
        self.optimizer = self.optimizer_cls(self.parameters,
                                            **self.optimizer_kwargs)
        if self.swa_start is not None:
            from torchcontrib.optim import SWA
            assert self.swa_freq is not None, self.swa_freq
            assert self.swa_lr is not None, self.swa_lr
            self.optimizer = SWA(self.optimizer,
                                 swa_start=self.swa_start,
                                 swa_freq=self.swa_freq,
                                 swa_lr=self.swa_lr)

    def check_if_set(self):
        assert self.optimizer is not None, \
            'The optimizer is not initialized, call set_parameter before' \
            ' using any of the optimizer functions'

    def zero_grad(self):
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        self.check_if_set()
        # Todo: report clipped and unclipped
        # Todo: allow clip=None but still report grad_norm
        grad_clips = self.gradient_clipping
        return torch.nn.utils.clip_grad_norm_(self.parameters, grad_clips)

    def to(self, device):
        if device is None:
            return
        self.check_if_set()
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=None):
        assert device is None or isinstance(device, int), device
        if device is None:
            device = torch.device('cuda')
        return self.to(device)

    def load_state_dict(self, state_dict):
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        self.check_if_set()
        return self.optimizer.state_dict()
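
Example 7's Optimizer wrapper leaves optimizer_cls unset, so it is intended to be subclassed. A hypothetical subclass and usage (the names below are illustrative, not from the source) might look like:

import torch

class AdamOptimizer(Optimizer):      # `Optimizer` is the wrapper class above
    optimizer_cls = torch.optim.Adam

model = torch.nn.Linear(10, 2)
opt = AdamOptimizer(gradient_clipping=5.0, lr=1e-3,
                    swa_start=1000, swa_freq=50, swa_lr=5e-4)
opt.set_parameters(model.parameters())   # builds Adam and, since swa_start is set, wraps it in SWA

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.clip_grad()   # clips the gradient norm to gradient_clipping
opt.step()
opt.zero_grad()
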
Example #8
def main():

    maxIOU = 0.0
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    #     test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)

    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if args.train:
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  #ignore_index=255 weight=w
        model = nn.DataParallel(model).cuda()

        for param in model.parameters():
            param.requires_grad = True

        optimizer1 = optim.SGD(model.parameters(),
                               lr=config.lr,
                               momentum=0.9,
                               weight_decay=1e-4)
        optimizer = SWA(optimizer1)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=(args.epochs // 9) + 1)

        dataset_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=args.train,
            pin_memory=True,
            num_workers=args.workers)

        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))

            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):

                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()

                losses.update(loss.item(), args.batch_size)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()

                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.avg:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))

            if epoch > 5:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, model_fname % (epoch + 1))
        optimizer.swap_swa_sgd()
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (665 + 1))
Example #9
def main(args):

    # Parameters for toy data and experiments
    plt.style.use('dark_background')
    rng_seed = 4200
    np.random.seed(rng_seed)

    wstar = torch.tensor([[0.973, 1.144]], dtype=torch.float)
    xs = torch.tensor(np.random.randn(50, 2), dtype=torch.float)
    labels = torch.mm(xs, wstar.T)

    p = torch.tensor(np.random.uniform(0.05, 1.0, xs.shape[0]), dtype=torch.float)
    ips = 1.0 / p
    n_iters = 500
    plot_every = n_iters // 10
    arrow_width = 0.012 * 0.01
    legend = {}
    data = {}
    colors = {}
    interpsize = 30

    # Figure
    fig = plt.figure(figsize=(8.4,5.4), dpi=150)
    ax = fig.add_subplot(111)
    ax.set_xlim([-0.3, 1.8])
    ax.set_ylim([-0.3, 1.8])

    # The loss function we want to optimize
    def loss_fn(out, y, mult):
        l2loss = (out - y) ** 2.0
        logl2loss = torch.log(1.0 + (out - y) ** 2.0)
        return torch.mean(mult * l2loss)

    # True IPS-weighted loss over all datapoints, used for plotting contour
    def f(x1, x2):
        w = torch.tensor([[x1], [x2]], dtype=torch.float)
        o = torch.mm(xs, w)
        return float(loss_fn(o, torch.mm(xs, wstar.reshape((2, 1))), ips))

    # Plot contour
    true_x1 = np.linspace(float(wstar[0, 0]) - 1.5, float(wstar[0, 0]) + 0.8) # - 1.5 / + 1.0
    true_x2 = np.linspace(float(wstar[0, 1]) - 1.5, float(wstar[0, 1]) + 1.2) # - 1.5 / + 1.0
    true_x1, true_x2 = np.meshgrid(true_x1, true_x2)
    true_y = np.array([
        [f(true_x1[i1, i2], true_x2[i1, i2]) for i2 in range(len(true_x2))]
        for i1 in range(len(true_x1))
    ])
    ax.contour(true_x1, true_x2, true_y, levels=10, colors='white', alpha=0.45)
    plt.plot(0.0, 0.0, marker='o', markersize=6, color="white")
    plt.plot(wstar[0, 0], wstar[0, 1], marker='*', markersize=6, color="white")

    # IPS-weighted approach
    for color_index, lr in enumerate([0.1, 0.01, 0.03]):  # 0.01, 0.03, 0.05, 0.1, 0.3
        if lr == 0.01:
            color = "violet"
        elif lr == 0.03:
            color = "orange"
        else:
            color = "lightgreen"
        #color = "C%d" % (color_index + 2)
        model = torch.nn.Linear(2, 1)
        optimizer = torch.optim.SGD(model.parameters(), lr=lr)
        optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
        with torch.no_grad():
            model.bias.zero_()
            model.weight.zero_()
        old_weights = np.copy(model.weight.data.numpy())
        np.random.seed(rng_seed + color_index + 1)
        label = f"IPS-SGD ($\\eta={lr}$)"
        data[label] = np.zeros((3, n_iters * interpsize + 1))
        colors[label] = color
        for t in range(n_iters):
            i = np.random.randint(xs.shape[0])
            x = xs[i, :]
            y = labels[i]
            optimizer.zero_grad()
            o = model(x)
            l = loss_fn(o, y, ips[i])
            l.backward()
            optimizer.step()

            # Record current iteration performance and location
            optimizer.swap_swa_sgd()
            x, y = model.weight.data.numpy()[0]
            optimizer.swap_swa_sgd()
            old_x, old_y, old_z = data[label][:, t * interpsize]
            xr = np.linspace(old_x, x, num=interpsize)
            yr = np.linspace(old_y, y, num=interpsize)
            for i in range(interpsize):
                data[label][:, 1 + t * interpsize + i] = np.array([
                    xr[i], yr[i], f(xr[i], yr[i])])
            #data[label][:, t] = np.array([x, y, f(x, y)])


    # Sample based approach
    lr = 10.0 # 1.0
    model = torch.nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    optimizer = SWA(optimizer, swa_start=0, swa_freq=1, swa_lr=lr)
    with torch.no_grad():
        model.bias.zero_()
        model.weight.zero_()
    old_weights = np.copy(model.weight.data.numpy())
    sample_probs = np.array(ips / torch.sum(ips))
    Mbar = float(np.mean(sample_probs))
    np.random.seed(rng_seed - 1)
    label = f"\\textsc{{CounterSample}} ($\\eta={lr}$)"
    data[label] = np.zeros((3, n_iters * interpsize + 1))
    for t in range(n_iters):
        i = np.argwhere(np.random.multinomial(1, sample_probs) == 1.0)[0, 0]
        x = xs[i, :]
        y = labels[i]
        optimizer.zero_grad()
        o = model(x)
        l = loss_fn(o, y, Mbar)
        l.backward()
        optimizer.step()

        # Record current iteration location and performance
        optimizer.swap_swa_sgd()
        x, y = model.weight.data.numpy()[0]
        optimizer.swap_swa_sgd()
        old_x, old_y, old_z = data[label][:, t * interpsize]
        xr = np.linspace(old_x, x, num=interpsize)
        yr = np.linspace(old_y, y, num=interpsize)
        for i in range(interpsize):
            data[label][:, 1 + t * interpsize + i] = np.array([
                xr[i], yr[i], f(xr[i], yr[i])])
        colors[label] = "deepskyblue"

    # Print summary to quickly find performance at convergence
    for label in data.keys():
        print(f"{label}: {data[label][2, -1]}")

    # Create legend
    lines = {}
    for label in data.keys():
        line = data[label]
        lines[label] = ax.plot(line[0, :], line[1, :], color=colors[label], label=label, linewidth=2.0)
        legend[colors[label]] = label

    ax.set_xlabel('$w_1$')
    ax.set_ylabel('$w_2$')
    legend_lines = [
        lines[legend["deepskyblue"]][0],
        lines[legend["orange"]][0],
        lines[legend["violet"]][0],
        lines[legend["lightgreen"]][0]]
    legend_labels = [
        "\\textsc{CounterSample}",
        "IPS-SGD (best learning rate)",
        "IPS-SGD (learning rate too small)",
        "IPS-SGD (learning rate too large)"]
    legend_artist = ax.legend(legend_lines, legend_labels, loc='lower right') #, bbox_to_anchor=(1.0 - 0.3, 0.25))

    # Update function for animation
    n_frames = n_iters * 2
    n_data = n_iters * interpsize
    from math import floor
    def transform_num(num):
        x = 3 * ((1.0 * num) / n_frames)
        y = (((x + 0.5)**2 - 0.5**2) / 12.0)
        return floor(y * n_data)

    def update_lines(num):
        print(f"frame {num:4d} / {n_frames:4d} [{num / n_frames * 100:.0f}%]", end="\r")
        num = transform_num(num)
        out = [legend_artist]
        for label in data.keys():
            line = lines[label][0]
            d = data[label]
            line.set_data(d[0:2, :num])
            out.append(line)
        return out

    # Write animation to file
    line_ani = animation.FuncAnimation(fig, update_lines, n_frames, interval=100, blit=False, repeat_delay=3000)
    writer = animation.FFMpegWriter(fps=60, codec='h264')   # for keynote
    line_ani.save(args.out, writer=writer)
    print("\033[K\n")
Example No. 10
def main():
    if args.config_path:
        if args.config_path in CONFIG_TREATER:
            load_path = CONFIG_TREATER[args.config_path]
        elif args.config_path.endswith(".yaml"):
            load_path = args.config_path
        else:
            load_path = "experiments/" + CONFIG_TREATER[
                args.config_path] + ".yaml"
        with open(load_path, 'rb') as fp:
            config = CfgNode.load_cfg(fp)
    else:
        config = None

    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True
    test_model = None
    max_epoch = config.TRAIN.NUM_EPOCHS
    print('data folder: ', args.data_folder)

    # WORLD_SIZE Generated by torch.distributed.launch.py
    #num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
    #is_distributed = num_gpus > 1
    #if is_distributed:
    #    torch.cuda.set_device(args.local_rank)
    #    torch.distributed.init_process_group(
    #        backend="nccl", init_method="env://",
    #    )

    model = get_model(config)
    model_loss = ModelLossWraper(
        model,
        config.TRAIN.CLASS_WEIGHTS,
        config.MODEL.IS_DISASTER_PRED,
        config.MODEL.IS_SPLIT_LOSS,
    ).cuda()

    #if args.local_rank == 0:
    #from IPython import embed; embed()

    #if is_distributed:
    #    model_loss = nn.SyncBatchNorm.convert_sync_batchnorm(model_loss)
    #    model_loss = nn.parallel.DistributedDataParallel(
    #        model_loss#, device_ids=[args.local_rank], output_device=args.local_rank
    #    )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.device_count() > 1:
        model_loss = nn.DataParallel(model_loss)

    model_loss.to(device)
    cpucount = multiprocessing.cpu_count()

    if config.mode.startswith("single"):
        trainset_loaders = {}
        loader_len = 0
        for disaster in disaster_list[config.mode[6:]]:
            trainset = XView2Dataset(args.data_folder,
                                     rgb_bgr='rgb',
                                     preprocessing={
                                         'flip': True,
                                         'scale': config.TRAIN.MULTI_SCALE,
                                         'crop': config.TRAIN.CROP_SIZE,
                                     },
                                     mode="singletrain",
                                     single_disaster=disaster)
            if len(trainset) > 0:
                train_sampler = None

                trainset_loader = torch.utils.data.DataLoader(
                    trainset,
                    batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
                    shuffle=train_sampler is None,
                    pin_memory=True,
                    drop_last=True,
                    sampler=train_sampler,
                    num_workers=cpucount if cpucount < 16 else cpucount // 3)

                trainset_loaders[disaster] = trainset_loader
                loader_len += len(trainset_loader)
                print("added disaster {} with {} samples".format(
                    disaster, len(trainset)))
            else:
                print("skipping disaster ", disaster)

    else:

        trainset = XView2Dataset(args.data_folder,
                                 rgb_bgr='rgb',
                                 preprocessing={
                                     'flip': True,
                                     'scale': config.TRAIN.MULTI_SCALE,
                                     'crop': config.TRAIN.CROP_SIZE,
                                 },
                                 mode=config.mode)

        #if is_distributed:
        #    train_sampler = DistributedSampler(trainset)
        #else:
        train_sampler = None

        trainset_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        loader_len = len(trainset_loader)

    model.train()

    lr_init = config.TRAIN.LR
    optimizer = torch.optim.SGD(
        [{
            'params': filter(lambda p: p.requires_grad, model.parameters()),
            'lr': lr_init
        }],
        lr=lr_init,
        momentum=0.9,
        weight_decay=0.,
        nesterov=False,
    )

    num_iters = max_epoch * loader_len

    if config.SWA:
        swa_start = num_iters
        optimizer = SWA(
            optimizer,
            swa_start=swa_start,
            swa_freq=4 * loader_len,
            swa_lr=0.001
        )  #SWA(optimizer, swa_start = None, swa_freq = None, swa_lr = None)#
        #scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0.0001, 0.05, step_size_up=1, step_size_down=2*len(trainset_loader)-1, mode='triangular', gamma=1.0, scale_fn=None, scale_mode='cycle', cycle_momentum=False, base_momentum=0.8, max_momentum=0.9, last_epoch=-1)
        lr = 0.0001
        #model.load_state_dict(torch.load("ckpt/dual-hrnet/hrnet_450", map_location='cpu')['state_dict'])
        #print("weights loaded")
        max_epoch = max_epoch + 40

    start_epoch = 0
    losses = AverageMeter()
    model.train()
    cur_iters = 0 if start_epoch == 0 else None
    for epoch in range(start_epoch, max_epoch):

        if config.mode.startswith("single"):
            all_batches = []
            total_len = 0
            for disaster in sorted(list(trainset_loaders.keys())):
                all_batches += [
                    (disaster, idx)
                    for idx in range(len(trainset_loaders[disaster]))
                ]
                total_len += len(trainset_loaders[disaster].dataset)
            all_batches = random.sample(all_batches, len(all_batches))
            iterators = {
                disaster: iter(trainset_loaders[disaster])
                for disaster in trainset_loaders.keys()
            }
            if cur_iters is not None:
                cur_iters += len(all_batches)
            else:
                cur_iters = epoch * len(all_batches)

            for i, (disaster, idx) in enumerate(all_batches):
                lr = optimizer.param_groups[0]['lr']
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)
                samples = next(iterators[disaster])
                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})\t'
                                'disaster: {dis}'.format(epoch + 1,
                                                         i + 1,
                                                         len(all_batches),
                                                         lr,
                                                         loss=losses,
                                                         dis=disaster))

            del iterators

        else:
            cur_iters = epoch * len(trainset_loader)

            for i, samples in enumerate(trainset_loader):
                lr = optimizer.param_groups[0]['lr']
                if not config.SWA or epoch < swa_start:
                    lr = adjust_learning_rate(optimizer, lr_init, num_iters,
                                              i + cur_iters)

                inputs_pre = samples['pre_img'].to(device)
                inputs_post = samples['post_img'].to(device)
                target = samples['mask_img'].to(device)
                #disaster_target = samples['disaster'].to(device)

                loss = model_loss(inputs_pre, inputs_post,
                                  target)  #, disaster_target)

                loss_sum = torch.sum(loss).detach().cpu()
                if np.isnan(loss_sum) or np.isinf(loss_sum):
                    print('check')
                losses.update(loss_sum, 4)  # batch size

                loss = torch.sum(loss)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                #if args.swa == "True":
                #scheduler.step()
                #if epoch%4 == 3 and i == len(trainset_loader)-2:
                #    optimizer.update_swa()

                if args.local_rank == 0 and i % 10 == 0:
                    logger.info('epoch: {0}\t'
                                'iter: {1}/{2}\t'
                                'lr: {3:.6f}\t'
                                'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                                    epoch + 1,
                                    i + 1,
                                    len(trainset_loader),
                                    lr,
                                    loss=losses))

        if args.local_rank == 0:
            if (epoch + 1) % 50 == 0 and test_model is None:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, os.path.join(ckpts_save_dir, 'hrnet_%s' % (epoch + 1)))
    if config.SWA:
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("preSWA")))
        optimizer.swap_swa_sgd()
        bn_loader = torch.utils.data.DataLoader(
            trainset,
            batch_size=2,
            shuffle=train_sampler is None,
            pin_memory=True,
            drop_last=True,
            sampler=train_sampler,
            num_workers=multiprocessing.cpu_count())
        bn_update(bn_loader, model, device='cuda')
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(ckpts_save_dir, 'hrnet_%s' % ("SWA")))
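Note: when config.SWA is set, the training loop above finishes by swapping the averaged weights in, re-estimating BatchNorm statistics with one pass over the training data, and saving the result. A stripped-down sketch of that finalization step, assuming torchcontrib's SWA wrapper (which ships the same bn_update routine as a static helper) and hypothetical model / train_loader / ckpt_path arguments:

# Sketch: finalize an SWA run before checkpointing (assumes torchcontrib.optim.SWA).
import torch
from torchcontrib.optim import SWA

def finalize_swa(model, optimizer, train_loader, ckpt_path, device="cuda"):
    # Replace the raw weights with the SWA running average held by the wrapper.
    optimizer.swap_swa_sgd()
    # BatchNorm running stats were accumulated for the raw weights, so make one
    # forward pass over the training data to recompute them for the averaged model.
    SWA.bn_update(train_loader, model, device=device)
    torch.save({"state_dict": model.state_dict()}, ckpt_path)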
Example No. 11
class Model() :
    def __init__(self, configuration, pre_embed=None) :
        configuration = deepcopy(configuration)
        self.configuration = deepcopy(configuration)

        configuration['model']['encoder']['pre_embed'] = pre_embed

        encoder_copy = deepcopy(configuration['model']['encoder'])
        self.Pencoder = Encoder.from_params(Params(configuration['model']['encoder'])).to(device)
        self.Qencoder = Encoder.from_params(Params(encoder_copy)).to(device)

        configuration['model']['decoder']['hidden_size'] = self.Pencoder.output_size
        self.decoder = AttnDecoderQA.from_params(Params(configuration['model']['decoder'])).to(device)

        self.bsize = configuration['training']['bsize']

        self.adversary_multi = AdversaryMulti(self.decoder)

        weight_decay = configuration['training'].get('weight_decay', 1e-5)
        self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters())
        self.optim = torch.optim.Adam(self.params, weight_decay=weight_decay, amsgrad=True)
        # self.optim = torch.optim.Adagrad(self.params, lr=0.05, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()

        import time
        dirname = configuration['training']['exp_dirname']
        basepath = configuration['training'].get('basepath', 'outputs')
        self.time_str = time.ctime().replace(' ', '_')
        self.dirname = os.path.join(basepath, dirname, self.time_str)

        self.swa_settings = configuration['training']['swa']
        if self.swa_settings[0]:
            self.swa_all_optim = SWA(self.optim)
            self.running_norms = []

    @classmethod
    def init_from_config(cls, dirname, **kwargs) :
        config = json.load(open(dirname + '/config.json', 'r'))
        config.update(kwargs)
        obj = cls(config)
        obj.load_values(dirname)
        return obj

    def get_param_buffer_norms(self):
        for p in self.swa_all_optim.param_groups[0]['params']:
            param_state = self.swa_all_optim.state[p]
            if 'swa_buffer' not in param_state:
                self.swa_all_optim.update_swa()

        norms = []
        for p in np.array(self.swa_all_optim.param_groups[0]['params'])[
            [1, 2, 5, 6, 10, 11, 14, 15, 18, 20, 24, 26]]:
            param_state = self.swa_all_optim.state[p]
            buf = np.squeeze(
                param_state['swa_buffer'].cpu().numpy())
            cur_state = np.squeeze(p.data.cpu().numpy())
            norm = np.linalg.norm(buf - cur_state)
            norms.append(norm)
        if self.swa_settings[3] == 2:
            return np.max(norms)
        return np.mean(norms)

    def total_iter_num(self):
        return self.swa_all_optim.param_groups[0]['step_counter']

    def iter_for_swa_update(self, iter_num):
        return iter_num > self.swa_settings[1] \
               and iter_num % self.swa_settings[2] == 0


    def check_and_update_swa(self):
        if self.iter_for_swa_update(self.total_iter_num()):
            cur_step_diff_norm = self.get_param_buffer_norms()
            if self.swa_settings[3] == 0:
                self.swa_all_optim.update_swa()
                return
            if not self.running_norms:
                running_mean_norm = 0
            else:
                running_mean_norm = np.mean(self.running_norms)

            if cur_step_diff_norm > running_mean_norm:
                self.swa_all_optim.update_swa()
                self.running_norms = [cur_step_diff_norm]
            elif cur_step_diff_norm > 0:
                self.running_norms.append(cur_step_diff_norm)

    def train(self, train_data, train=True) :
        docs_in = train_data.P
        question_in = train_data.Q
        entity_masks_in = train_data.E
        target_in = train_data.A

        sorting_idx = get_sorting_index_with_noise_from_lengths([len(x) for x in docs_in], noise_frac=0.1)
        docs = [docs_in[i] for i in sorting_idx]
        questions = [question_in[i] for i in sorting_idx]
        entity_masks = [entity_masks_in[i] for i in sorting_idx]
        target = [target_in[i] for i in sorting_idx]
        
        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)
        loss_total = 0

        batches = list(range(0, N, bsize))
        batches = shuffle(batches)

        for n in tqdm(batches) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            batch_target = target[n:n+bsize]
            batch_target = torch.LongTensor(batch_target).to(device)

            ce_loss = self.criterion(batch_data.predict, batch_target)

            loss = ce_loss

            if hasattr(batch_data, 'reg_loss') :
                loss += batch_data.reg_loss

            if train :
                if self.swa_settings[0]:
                    self.check_and_update_swa()

                    self.swa_all_optim.zero_grad()
                    loss.backward()
                    self.swa_all_optim.step()
                else:
                    self.optim.zero_grad()
                    loss.backward()
                    self.optim.step()

            loss_total += float(loss.data.cpu().item())
        if self.swa_settings[0] and self.swa_all_optim.param_groups[0][
            'step_counter'] > self.swa_settings[1]:
            print("\nSWA swapping\n")
            # self.attn_optim.swap_swa_sgd()
            # self.encoder_optim.swap_swa_sgd()
            # self.decoder_optim.swap_swa_sgd()
            self.swa_all_optim.swap_swa_sgd()
            self.running_norms = []
        return loss_total*bsize/N

    def evaluate(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E
        
        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)

        outputs = []
        attns = []
        scores = []
        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            prediction_scores = batch_data.predict.cpu().data.numpy()
            batch_data.predict = torch.argmax(batch_data.predict, dim=-1)
            if self.decoder.use_attention :
                attn = batch_data.attn
                attns.append(attn.cpu().data.numpy())

            predict = batch_data.predict.cpu().data.numpy()
            outputs.append(predict)
            scores.append(prediction_scores)

            

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        scores = [x for y in scores for x in y]

        return outputs, attns, scores

    def save_values(self, use_dirname=None, save_model=True) :
        if use_dirname is not None :
            dirname = use_dirname
        else :
            dirname = self.dirname
        os.makedirs(dirname, exist_ok=True)
        shutil.copy2(file_name, dirname + '/')
        json.dump(self.configuration, open(dirname + '/config.json', 'w'))

        if save_model :
            torch.save(self.Pencoder.state_dict(), dirname + '/encP.th')
            torch.save(self.Qencoder.state_dict(), dirname + '/encQ.th')
            torch.save(self.decoder.state_dict(), dirname + '/dec.th')

        return dirname

    def load_values(self, dirname) :
        self.Pencoder.load_state_dict(torch.load(dirname + '/encP.th'))
        self.Qencoder.load_state_dict(torch.load(dirname + '/encQ.th'))
        self.decoder.load_state_dict(torch.load(dirname + '/dec.th'))

    def permute_attn(self, data, num_perm=100) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()

        bsize = self.bsize
        N = len(questions)

        permutations_predict = []
        permutations_diff = []

        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            predict_true = batch_data.predict.clone().detach()

            batch_perms_predict = np.zeros((batch_data.P.B, num_perm))
            batch_perms_diff = np.zeros((batch_data.P.B, num_perm))

            for i in range(num_perm) :
                batch_data.permute = True
                self.decoder(batch_data)

                predict = torch.argmax(batch_data.predict, dim=-1)
                batch_perms_predict[:, i] = predict.cpu().data.numpy()
            
                predict_difference = self.adversary_multi.output_diff(batch_data.predict, predict_true)
                batch_perms_diff[:, i] = predict_difference.squeeze(-1).cpu().data.numpy()
                
            permutations_predict.append(batch_perms_predict)
            permutations_diff.append(batch_perms_diff)

        permutations_predict = [x for y in permutations_predict for x in y]
        permutations_diff = [x for y in permutations_diff for x in y]
        
        return permutations_predict, permutations_diff

    def adversarial_multi(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.eval()
        self.Qencoder.eval()
        self.decoder.eval()

        print(self.adversary_multi.K)
        
        self.params = list(self.Pencoder.parameters()) + list(self.Qencoder.parameters()) + list(self.decoder.parameters())

        for p in self.params :
            p.requires_grad = False

        bsize = self.bsize
        N = len(questions)
        batches = list(range(0, N, bsize))

        outputs, attns, diffs = [], [], []

        for n in tqdm(batches) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            self.adversary_multi(batch_data)

            predict_volatile = torch.argmax(batch_data.predict_volatile, dim=-1)
            outputs.append(predict_volatile.cpu().data.numpy())
            
            attn = batch_data.attn_volatile
            attns.append(attn.cpu().data.numpy())

            predict_difference = self.adversary_multi.output_diff(batch_data.predict_volatile, batch_data.predict.unsqueeze(1))
            diffs.append(predict_difference.cpu().data.numpy())

        outputs = [x for y in outputs for x in y]
        attns = [x for y in attns for x in y]
        diffs = [x for y in diffs for x in y]
        
        return outputs, attns, diffs

    def gradient_mem(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)

        grads = {'XxE' : [], 'XxE[X]' : [], 'H' : []}

        for n in range(0, N, bsize) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            batch_data.P.keep_grads = True
            batch_data.detach = True

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)
            
            max_predict = torch.argmax(batch_data.predict, dim=-1)
            prob_predict = nn.Softmax(dim=-1)(batch_data.predict)

            max_class_prob = torch.gather(prob_predict, -1, max_predict.unsqueeze(-1))
            max_class_prob.sum().backward()

            g = batch_data.P.embedding.grad
            em = batch_data.P.embedding
            g1 = (g * em).sum(-1)
            
            grads['XxE[X]'].append(g1.cpu().data.numpy())
            
            g1 = (g * self.Pencoder.embedding.weight.sum(0)).sum(-1)
            grads['XxE'].append(g1.cpu().data.numpy())
            
            g1 = batch_data.P.hidden.grad.sum(-1)
            grads['H'].append(g1.cpu().data.numpy())


        for k in grads :
            grads[k] = [x for y in grads[k] for x in y]
                    
        return grads       

    def remove_and_run(self, data) :
        docs = data.P
        questions = data.Q
        entity_masks = data.E

        self.Pencoder.train()
        self.Qencoder.train()
        self.decoder.train()
        
        bsize = self.bsize
        N = len(questions)
        output_diffs = []

        for n in tqdm(range(0, N, bsize)) :
            torch.cuda.empty_cache()
            batch_doc = docs[n:n+bsize]
            batch_ques = questions[n:n+bsize]
            batch_entity_masks = entity_masks[n:n+bsize]

            batch_doc = BatchHolder(batch_doc)
            batch_ques = BatchHolder(batch_ques)

            batch_data = BatchMultiHolder(P=batch_doc, Q=batch_ques)
            batch_data.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

            self.Pencoder(batch_data.P)
            self.Qencoder(batch_data.Q)
            self.decoder(batch_data)

            po = np.zeros((batch_data.P.B, batch_data.P.maxlen))

            for i in range(1, batch_data.P.maxlen - 1) :
                batch_doc = BatchHolder(docs[n:n+bsize])

                batch_doc.seq = torch.cat([batch_doc.seq[:, :i], batch_doc.seq[:, i+1:]], dim=-1)
                batch_doc.lengths = batch_doc.lengths - 1
                batch_doc.masks = torch.cat([batch_doc.masks[:, :i], batch_doc.masks[:, i+1:]], dim=-1)

                batch_data_loop = BatchMultiHolder(P=batch_doc, Q=batch_ques)
                batch_data_loop.entity_mask = torch.ByteTensor(np.array(batch_entity_masks)).to(device)

                self.Pencoder(batch_data_loop.P)
                self.decoder(batch_data_loop)

                predict_difference = self.adversary_multi.output_diff(batch_data_loop.predict, batch_data.predict)

                po[:, i] = predict_difference.squeeze(-1).cpu().data.numpy()

            output_diffs.append(po)

        output_diffs = [x for y in output_diffs for x in y]
        
        return output_diffs
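Note: the Model class above drives the SWA wrapper in manual mode (SWA(self.optim) with no schedule) and only calls update_swa() when the current weights have drifted far enough from the running average. A condensed sketch of that gating heuristic, using hypothetical names, a toy loss, and a fixed threshold instead of the running mean of past norms used above:

# Sketch: manual-mode SWA with a drift-gated update (assumes torchcontrib.optim.SWA).
import numpy as np
import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(4, 2)
opt = SWA(torch.optim.Adam(model.parameters()))    # manual mode: no swa_start / swa_freq

def drift_from_average(optimizer):
    """Mean L2 distance between the current params and their SWA buffers."""
    norms = []
    for p in optimizer.param_groups[0]["params"]:
        state = optimizer.state[p]
        if "swa_buffer" not in state:
            return float("inf")                    # no average yet: force the first update
        norms.append((p.data - state["swa_buffer"]).norm().item())
    return float(np.mean(norms))

for step in range(100):
    x = torch.randn(8, 4)
    opt.zero_grad()
    model(x).pow(2).mean().backward()              # hypothetical stand-in loss
    opt.step()
    # Fold the weights into the average only when they have drifted far enough
    # from it (a fixed threshold here; the class above compares to a running mean).
    if drift_from_average(opt) > 0.05:
        opt.update_swa()

opt.swap_swa_sgd()                                 # swap the averaged weights in for evaluation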
Example No. 12
class TPUFitter:
    
    def __init__(self, model, device, config, base_model_path='/', model_name='unnamed', model_prefix='roberta', model_version='v1', out_path='/', log_path='/'):
        self.log_path = Path(log_path, 'log').with_suffix('.txt')
        self.log('TPUFitter initialization started.', direct_out=True)
        self.config = config
        self.epoch = 0
        self.base_model_path = base_model_path
        self.model_name = model_name
        self.model_version = model_version
        self.model_path = Path(self.base_model_path, self.model_name, self.model_version)
        
        self.out_path = out_path
        self.node_path = Path(self.out_path, 'node_submissions')
        self.create_dir_structure()

        self.model = model
        self.device = device
        # whether to use stochastic weight averaging
        self.use_SWA = config.use_SWA
        # whether to use a different lr for the backbone and the classifier head
        self.use_diff_lr = config.use_diff_lr
        
        self._set_optimizer_scheduler()
        self.criterion = config.criterion
        self.best_score = -1.0
        self.log(f'Fitter prepared. Device is {self.device}', direct_out=True)
    
    def create_dir_structure(self):
        self.node_path.mkdir(parents=True, exist_ok=True)
        self.log(f'**** Directory structure created ****', direct_out=True)
    
    def _set_optimizer_scheduler(self):
        self.log('Optimizer and scheduler initialization started.', direct_out=True)
        def is_backbone(n):
            return 'backbone' in n

        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

        # use different learning rate for backbone transformer and classifier head
        if self.use_diff_lr:
            backbone_lr, head_lr = self.config.lr*xm.xrt_world_size(), self.config.lr*xm.xrt_world_size()*500
            optimizer_grouped_parameters = [
                # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                {"params": [p for n, p in param_optimizer if is_backbone(n)], "lr": backbone_lr},
                {"params": [p for n, p in param_optimizer if not is_backbone(n)], "lr": head_lr}
            ]
            self.log(f'Different Learning rate for backbone: {backbone_lr} head:{head_lr}')
        else:
            optimizer_grouped_parameters = [
                {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
                {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
                ]
        
        try:
            self.optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size())
            # self.optimizer = SGD(optimizer_grouped_parameters, lr=self.config.lr*xm.xrt_world_size(), momentum=0.9)
        except:
            param_g_1 = [p for n, p in param_optimizer if is_backbone(n)]
            param_g_2 = [p for n, p in param_optimizer if not is_backbone(n)]
            param_intersect = list(set(param_g_1) & set(param_g_2))
            self.log(f'intersect: {param_intersect}', direct_out=True)

        if self.use_SWA:
            self.optimizer = SWA(self.optimizer)
        
        if 'num_training_steps' in self.config.scheduler_params:
            num_training_steps = int(self.config.train_lenght / self.config.batch_size / xm.xrt_world_size() * self.config.n_epochs)
            self.log(f'Number of training steps: {num_training_steps}', direct_out=True)
            self.config.scheduler_params['num_training_steps'] = num_training_steps
        
        self.scheduler = self.config.SchedulerClass(self.optimizer, **self.config.scheduler_params)

    def fit(self, train_loader, validation_loader, n_epochs=None):
        self.log(f'**** Fitting process has been started ****', direct_out=True)
        if n_epochs is None:
            n_epochs = self.config.n_epochs
        
        for e in range(n_epochs):
            if self.config.verbose:
                lr = self.optimizer.param_groups[0]['lr']
                timestamp = datetime.utcnow().isoformat()
                self.log(f'\n{timestamp}\nLR: {lr} \nEpoch:{e}')

            t = time.time()
            para_loader = pl.ParallelLoader(train_loader, [self.device])
            losses, final_scores = self.train_one_epoch(para_loader.per_device_loader(self.device), e)
            
            self.log(f'[RESULT]: Train. Epoch: {self.epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

            t = time.time()
            para_loader = pl.ParallelLoader(validation_loader, [self.device])
            
            # swap SWA weights for validation
            if self.use_SWA:
                self.log('Swapping SWA weights for validation', direct_out=True)
                self.optimizer.swap_swa_sgd()
            
            losses, final_scores, threshold = self.validation(para_loader.per_device_loader(self.device))
            self.log(f'[RESULT]: Validation. Epoch: {self.epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, best_th: {threshold.find:.3f}, time: {(time.time() - t):.5f}')
            # swap back to normal weights to continue training
            if self.use_SWA:
                self.log('Swapping back to original weights to continue training', direct_out=True)
                self.optimizer.swap_swa_sgd()
            
            if final_scores.avg > self.best_score:
                self.best_score = final_scores.avg
                self.save('best_model')
                self.log('Best model has been updated', direct_out=True)
                # after the epoch, update the SWA model if the validation score improved
                if self.use_SWA:
                    self.optimizer.update_swa()
                    self.log('SWA model weights have been updated', direct_out=True)

            if self.config.validation_scheduler:
                # self.scheduler.step(metrics=final_scores.avg)
                self.scheduler.step()
            
            self.epoch += 1
    
    def run_tuning_and_inference(self, test_loader, validation_loader, validation_tune_loader, n_epochs):
        self.log('******Validation tuning and inference started*****', direct_out=True)
        self.run_validation_tuning(validation_loader, validation_tune_loader, n_epochs)
        para_loader = pl.ParallelLoader(test_loader, [self.device])
        self.run_inference(para_loader.per_device_loader(self.device))
    
    def run_validation_tuning(self, validation_loader, validation_tune_loader, n_epochs):
        self.log('******Validation tuning started*****', direct_out=True)
        # self.optimizer.param_groups[0]['lr'] = self.config.lr*xm.xrt_world_size() / (epoch + 1)
        self.fit(validation_tune_loader, validation_loader, n_epochs)
    
    def validation(self, val_loader):
        self.log(f'**** Validation process has been started ****', direct_out=True)
        self.model.eval()
        losses = AverageMeter()
        final_scores = RocAucMeter()
        threshold = ThresholdMeter()

        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(val_loader):
            self.log(
                f'Valid Step {step}, loss: ' + \
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                f'time: {(time.time() - t):.5f}', step=step
            )
            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long) 
                attention_masks = attention_masks.to(self.device, dtype=torch.long) 
                targets = targets.to(self.device, dtype=torch.float) 

                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)
                
                batch_size = inputs.size(0)

                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)
                threshold.update(targets, outputs)
        
        return losses, final_scores, threshold

    def train_one_epoch(self, train_loader, epoch):
        self.log(f'**** Epoch training has started: {epoch} ****', direct_out=True)
        self.model.train()

        losses = AverageMeter()
        final_scores = RocAucMeter()
        t = time.time()
        for step, (targets, inputs, attention_masks) in enumerate(train_loader):
            self.log(
                f'Train Step {step}, loss: ' + \
                f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                f'time: {(time.time() - t):.5f}', step=step
            )

            inputs = inputs.to(self.device, dtype=torch.long)
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            targets = targets.to(self.device, dtype=torch.float)

            self.optimizer.zero_grad()

            outputs = self.model(inputs, attention_masks)
            loss = self.criterion(outputs, targets)

            batch_size = inputs.size(0)
            
            final_scores.update(targets, outputs)
            losses.update(loss.detach().item(), batch_size)

            loss.backward()
            xm.optimizer_step(self.optimizer)

            if self.config.step_scheduler:
                self.scheduler.step()
        
        return losses, final_scores

    def run_inference(self, test_loader):
        self.log(f'**** Inference process has been started ****', direct_out=True)
        self.model.eval()
        result = {'id': [], 'toxic': []}
        
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):
            self.log(f'Prediction Step {step}, time: {(time.time() - t):.5f}', step=step)

            with torch.no_grad():
                inputs = inputs.to(self.device, dtype=torch.long) 
                attention_masks = attention_masks.to(self.device, dtype=torch.long)
                outputs = self.model(inputs, attention_masks)
                toxics = torch.nn.functional.softmax(outputs, dim=1).data.cpu().numpy()[:,1]

            result['id'].extend(ids.cpu().numpy())
            result['toxic'].extend(toxics)

        result = pd.DataFrame(result)
        print(f'Node path is: {self.node_path}')
        node_count = len(list(self.node_path.glob('*.csv')))
        result.to_csv(self.node_path/f'submission_{node_count}_{datetime.utcnow().microsecond}_{random.random()}.csv', index=False)

    def run_pseudolabeling(self, test_loader, epoch):
        losses = AverageMeter()
        final_scores = RocAucMeter()

        self.model.eval()
        
        t = time.time()
        for step, (ids, inputs, attention_masks) in enumerate(test_loader):

            inputs = inputs.to(self.device, dtype=torch.long) 
            attention_masks = attention_masks.to(self.device, dtype=torch.long)
            outputs = self.model(inputs, attention_masks)
            # print(f'Inputs: {inputs} size: {inputs.size()}')
            # print(f'outputs: {outputs} size: {outputs.size()}')
            toxics = torch.nn.functional.softmax(outputs, dim=1)[:,1]
            toxic_mask = (toxics<=0.4) | (toxics>=0.8)
            # print(attention_masks.size())
            toxics = toxics[toxic_mask]
            inputs = inputs[toxic_mask]
            attention_masks = attention_masks[toxic_mask]
            # print(f'toxics: {toxics.size()}')
            # print(f'inputs: {inputs.size()}')
            if toxics.nelement() != 0:
                targets_int = (toxics>self.config.pseudolabeling_threshold).int()
                targets = torch.stack([onehot(2, target) for target in targets_int])
                # print(targets_int)
                
                self.model.train()
                self.log(
                    f'Pseudolabeling Step {step}, loss: ' + \
                    f'{losses.avg:.5f}, final_score: {final_scores.avg:.5f}, ' + \
                    f'time: {(time.time() - t):.5f}', step=step
                )
    
                targets = targets.to(self.device, dtype=torch.float)
    
                self.optimizer.zero_grad()
    
                outputs = self.model(inputs, attention_masks)
                loss = self.criterion(outputs, targets)
    
                batch_size = inputs.size(0)
                
                final_scores.update(targets, outputs)
                losses.update(loss.detach().item(), batch_size)
    
                loss.backward()
                xm.optimizer_step(self.optimizer)
    
                if self.config.step_scheduler:
                    self.scheduler.step()
    
        self.log(f'[RESULT]: Pseudolabeling. Epoch: {epoch}, loss: {losses.avg:.5f}, final_score: {final_scores.avg:.5f}, time: {(time.time() - t):.5f}')

    def get_submission(self, out_dir):
        submission = pd.concat([pd.read_csv(path) for path in (out_dir/'node_submissions').glob('*.csv')]).groupby('id').mean()
        return submission
    
    def save(self, name):
        self.model_path.mkdir(parents=True, exist_ok=True)
        path = (self.model_path/name).with_suffix('.bin')
        
        if self.use_SWA:
            self.optimizer.swap_swa_sgd()

        xm.save(self.model.state_dict(), path)
        self.log(f'Model has been saved')

    def log(self, message, step=None, direct_out=False):
        if direct_out or self.config.verbose:
            if direct_out or step is None or (step is not None and step % self.config.verbose_step == 0):
                xm.master_print(message)
                with open(self.log_path, 'a+') as logger:
                    xm.master_print(f'{message}', logger)
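Note: TPUFitter trains on the raw weights and only swaps the SWA average in around validation and checkpointing, swapping back before the next epoch. The core of that bookkeeping as a minimal sketch, with a hypothetical evaluate() callable standing in for the TPU validation loop:

# Sketch: validate and checkpoint on SWA weights, then keep training on the raw
# weights (assumes torchcontrib.optim.SWA; evaluate() is a hypothetical stand-in).
import torch
from torchcontrib.optim import SWA

model = torch.nn.Linear(10, 2)
opt = SWA(torch.optim.AdamW(model.parameters(), lr=1e-3))
opt.update_swa()                                   # seed the average so the first swap is defined

def evaluate(model):
    return torch.rand(()).item()                   # placeholder validation score

best_score = -1.0
for epoch in range(3):
    # ... one epoch of training with opt.zero_grad() / opt.step() goes here ...

    opt.swap_swa_sgd()                             # validate the averaged weights
    score = evaluate(model)
    improved = score > best_score
    if improved:
        best_score = score
        torch.save(model.state_dict(), "best_model.bin")
    opt.swap_swa_sgd()                             # restore the raw weights for training

    if improved:
        opt.update_swa()                           # grow the average only when validation improved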
Example No. 13
class Train(object):
    """Train class.
  """
    def __init__(self, ):

        trainds = AlaskaDataIter(cfg.DATA.root_path,
                                 cfg.DATA.train_txt_path,
                                 training_flag=True)
        self.train_ds = DataLoader(trainds,
                                   cfg.TRAIN.batch_size,
                                   num_workers=cfg.TRAIN.process_num,
                                   shuffle=True)

        valds = AlaskaDataIter(cfg.DATA.root_path,
                               cfg.DATA.val_txt_path,
                               training_flag=False)
        self.val_ds = DataLoader(valds,
                                 cfg.TRAIN.batch_size,
                                 num_workers=cfg.TRAIN.process_num,
                                 shuffle=False)

        self.init_lr = cfg.TRAIN.init_lr
        self.warup_step = cfg.TRAIN.warmup_step
        self.epochs = cfg.TRAIN.epoch
        self.batch_size = cfg.TRAIN.batch_size
        self.l2_regularization = cfg.TRAIN.weight_decay_factor

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else 'cpu')

        self.model = CenterNet().to(self.device)

        self.load_weight()

        if 'Adamw' in cfg.TRAIN.opt:

            self.optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=self.init_lr,
                eps=1.e-5,
                weight_decay=self.l2_regularization)
        else:
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=self.init_lr,
                momentum=0.9,
                weight_decay=self.l2_regularization)

        if cfg.TRAIN.SWA > 0:
            ##use swa
            self.optimizer = SWA(self.optimizer)

        if cfg.TRAIN.mix_precision:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O1")

        self.model = nn.DataParallel(self.model)

        self.ema = EMA(self.model, 0.999)

        self.ema.register()
        ###control vars
        self.iter_num = 0

        # self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,mode='max', patience=3,verbose=True)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, self.epochs, eta_min=1.e-6)

        self.criterion = CenterNetLoss().to(self.device)

    def custom_loop(self):
        """Custom training and testing loop.
    Args:
      train_dist_dataset: Training dataset created using strategy.
      test_dist_dataset: Testing dataset created using strategy.
      strategy: Distribution strategy.
    Returns:
      train_loss, train_accuracy, test_loss, test_accuracy
    """
        def train_epoch(epoch_num):

            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()
            self.model.train()

            if cfg.MODEL.freeze_bn:
                for m in self.model.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
                        if cfg.MODEL.freeze_bn_affine:
                            m.weight.requires_grad = False
                            m.bias.requires_grad = False
            for image, hm_target, wh_target, weights in self.train_ds:

                if epoch_num < 10:
                    ###execute warm up during the early epochs
                    if self.warup_step > 0:
                        if self.iter_num < self.warup_step:
                            for param_group in self.optimizer.param_groups:
                                param_group['lr'] = self.iter_num / float(
                                    self.warup_step) * self.init_lr
                                lr = param_group['lr']

                            logger.info('warm up with learning rate: [%f]' %
                                        (lr))

                start = time.time()

                if cfg.TRAIN.vis:
                    for i in range(image.shape[0]):

                        img = image[i].numpy()
                        img = np.transpose(img, axes=[1, 2, 0])
                        hm = hm_target[i].numpy()
                        wh = wh_target[i].numpy()

                        if cfg.DATA.use_int8_data:
                            hm = hm[:, :, 0].astype(np.uint8)
                            wh = wh[:, :, 0]
                        else:
                            hm = hm[:, :, 0].astype(np.float32)
                            wh = wh[:, :, 0].astype(np.float32)

                        cv2.namedWindow('s_hm', 0)
                        cv2.imshow('s_hm', hm)
                        cv2.namedWindow('s_wh', 0)
                        cv2.imshow('s_wh', wh + 1)
                        cv2.namedWindow('img', 0)
                        cv2.imshow('img', img)
                        cv2.waitKey(0)
                else:
                    data = image.to(self.device).float()

                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(
                            self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()
                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()

                    batch_size = data.shape[0]

                    cls, wh = self.model(data)

                    cls_loss, wh_loss = self.criterion(
                        [cls, wh], [hm_target, wh_target, weights])

                    current_loss = cls_loss + wh_loss
                    summary_loss_cls.update(cls_loss.detach().item(),
                                            batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)
                    self.optimizer.zero_grad()

                    if cfg.TRAIN.mix_precision:
                        with amp.scale_loss(current_loss,
                                            self.optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        current_loss.backward()

                    self.optimizer.step()
                    if cfg.TRAIN.ema:
                        self.ema.update()
                    self.iter_num += 1
                    time_cost_per_batch = time.time() - start

                    images_per_sec = cfg.TRAIN.batch_size * cfg.TRAIN.num_gpu / time_cost_per_batch

                    if self.iter_num % cfg.TRAIN.log_interval == 0:

                        log_message = '[TRAIN], '\
                                      'Epoch %d Step %d, ' \
                                      'summary_loss: %.6f, ' \
                                      'cls_loss: %.6f, '\
                                      'wh_loss: %.6f, ' \
                                      'time: %.6f, '\
                                      'speed %d images/persec'% (
                                          epoch_num,
                                          self.iter_num,
                                          summary_loss_cls.avg+summary_loss_wh.avg,
                                          summary_loss_cls.avg ,
                                          summary_loss_wh.avg,
                                          time.time() - start,
                                          images_per_sec)
                        logger.info(log_message)

                if cfg.TRAIN.SWA > 0 and epoch_num >= cfg.TRAIN.SWA:
                    self.optimizer.update_swa()

            return summary_loss_cls, summary_loss_wh

        def test_epoch(epoch_num):
            summary_loss_cls = AverageMeter()
            summary_loss_wh = AverageMeter()

            self.model.eval()
            t = time.time()
            with torch.no_grad():
                for step, (image, hm_target, wh_target,
                           weights) in enumerate(self.val_ds):

                    data = image.to(self.device).float()

                    if cfg.DATA.use_int8_data:
                        hm_target = hm_target.to(
                            self.device).float() / cfg.DATA.use_int8_enlarge
                    else:
                        hm_target = hm_target.to(self.device).float()

                    wh_target = wh_target.to(self.device).float()
                    weights = weights.to(self.device).float()
                    batch_size = data.shape[0]

                    with torch.no_grad():
                        cls, wh = self.model(data)

                    cls_loss, wh_loss = self.criterion(
                        [cls, wh], [hm_target, wh_target, weights])

                    summary_loss_cls.update(cls_loss.detach().item(),
                                            batch_size)
                    summary_loss_wh.update(wh_loss.detach().item(), batch_size)

                    if step % cfg.TRAIN.log_interval == 0:

                        log_message =   '[VAL], '\
                                        'Epoch %d Step %d, ' \
                                        'summary_loss: %.6f, ' \
                                        'cls_loss: %.6f, '\
                                        'wh_loss: %.6f, ' \
                                        'time: %.6f' % (epoch_num,
                                                        step,
                                                        summary_loss_cls.avg+summary_loss_wh.avg,
                                                        summary_loss_cls.avg,
                                                        summary_loss_wh.avg,
                                                        time.time() - t)

                        logger.info(log_message)

            return summary_loss_cls, summary_loss_wh

        for epoch in range(self.epochs):

            for param_group in self.optimizer.param_groups:
                lr = param_group['lr']
            logger.info('learning rate: [%f]' % (lr))
            t = time.time()

            summary_loss_cls, summary_loss_wh = train_epoch(epoch)

            train_epoch_log_message = '[centernet], '\
                                      '[RESULT]: Train. Epoch: %d,' \
                                      ' summary_loss: %.5f,' \
                                      ' cls_loss: %.6f, ' \
                                      ' wh_loss: %.6f, ' \
                                      ' time:%.5f' % (epoch,
                                                      summary_loss_cls.avg+summary_loss_wh.avg,
                                                      summary_loss_cls.avg,
                                                      summary_loss_wh.avg,
                                                      (time.time() - t))
            logger.info(train_epoch_log_message)

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:

                ###switch to avg model
                self.optimizer.swap_swa_sgd()

            ##switch to ema weights
            if cfg.TRAIN.ema:
                self.ema.apply_shadow()

            if epoch % cfg.TRAIN.test_interval == 0:

                summary_loss_cls, summary_loss_wh = test_epoch(epoch)

                val_epoch_log_message = '[centernet], '\
                                        '[RESULT]: VAL. Epoch: %d,' \
                                        ' summary_loss: %.5f,' \
                                        ' cls_loss: %.6f, ' \
                                        ' wh_loss: %.6f, ' \
                                        ' time:%.5f' % (epoch,
                                                        summary_loss_cls.avg+summary_loss_wh.avg,
                                                        summary_loss_cls.avg,
                                                        summary_loss_wh.avg,
                                                        (time.time() - t))
                logger.info(val_epoch_log_message)

            self.scheduler.step()
            # self.scheduler.step(final_scores.avg)

            #### save model
            if not os.access(cfg.MODEL.model_path, os.F_OK):
                os.mkdir(cfg.MODEL.model_path)

            #### save the model at the end of every epoch
            current_model_saved_name = './model/centernet_epoch_%d_val_loss%.6f.pth' % (
                epoch, summary_loss_cls.avg + summary_loss_wh.avg)

            logger.info('Model saved to %s' % current_model_saved_name)
            torch.save(self.model.module.state_dict(),
                       current_model_saved_name)

            ####switch back
            if cfg.TRAIN.ema:
                self.ema.restore()

            if cfg.TRAIN.SWA > 0 and epoch >= cfg.TRAIN.SWA:
                ###switch back to plain model to train next epoch
                self.optimizer.swap_swa_sgd()

    def load_weight(self):
        if cfg.MODEL.pretrained_model is not None:
            state_dict = torch.load(cfg.MODEL.pretrained_model,
                                    map_location=self.device)
            self.model.load_state_dict(state_dict, strict=False)
Example No. 14
def train(opt):
    if torch.cuda.is_available():
        # num_gpus = torch.cuda.device_count()
        device = 'cuda'
        torch.cuda.manual_seed(123)
        num_gpus = 1
    else:
        num_gpus = 1
        device = 'cpu'
        torch.manual_seed(123)

    training_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": True,
        "drop_last": True,
        "collate_fn": collater_train,
        "num_workers": opt.num_worker,
        "pin_memory": True
    }

    test_params = {
        "batch_size": opt.batch_size * num_gpus,
        "shuffle": False,
        "drop_last": False,
        "collate_fn": collater_test,
        "num_workers": opt.num_worker,
        "pin_memory": True
    }

    train_dataset = VOCDetection(
        train=True,
        root=opt.train_dataset_root,
        transform=train_transform(
            width=EFFICIENTDET[opt.network]['input_size'],
            height=EFFICIENTDET[opt.network]['input_size'],
            lamda_norm=False))

    test_dataset = VOCDetection(
        train=False,
        root=opt.test_dataset_root,
        transform=transforms.Compose([
            Normalizer(lamda_norm=False, grey_p=0.0),
            Resizer(EFFICIENTDET[opt.network]['input_size'])
        ]))

    test_dataset_grey = VOCDetection(
        train=False,
        root=opt.test_dataset_root,
        transform=transforms.Compose([
            Normalizer(lamda_norm=False, grey_p=1.0),
            Resizer(EFFICIENTDET[opt.network]['input_size'])
        ]))

    train_generator = DataLoader(train_dataset, **training_params)
    test_generator = DataLoader(test_dataset, **test_params)
    test_grey_generator = DataLoader(test_dataset_grey, **test_params)

    network_id = int(''.join(filter(str.isdigit, opt.network)))
    loss_func = FocalLoss(alpha=opt.alpha,
                          gamma=opt.gamma,
                          smoothing_factor=opt.smoothing_factor)
    model = EfficientDet(MODEL_MAP[opt.network],
                         image_size=[
                             EFFICIENTDET[opt.network]['input_size'],
                             EFFICIENTDET[opt.network]['input_size']
                         ],
                         num_classes=train_dataset.num_classes(),
                         compound_coef=network_id,
                         num_anchors=9,
                         advprop=True,
                         from_pretrain=opt.from_pretrain)
    anchors_finder = Anchors()

    model.to(device)

    if opt.resume is not None:
        _ = resume(model, device, opt.resume)

    model = nn.DataParallel(model)

    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)

    os.makedirs(opt.log_path)

    if not os.path.isdir(opt.checkpoint_root_dir):
        os.makedirs(opt.checkpoint_root_dir)

    writer = SummaryWriter(opt.log_path)

    base_optimizer = torch.optim.Adam(model.parameters(),
                                      lr=opt.lr,
                                      weight_decay=opt.weight_decay,
                                      amsgrad=True)
    # optimizer = base_optimizer
    optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    if opt.resume is not None:
        model.eval()

        loss_regression_ls = []
        loss_classification_ls = []
        with torch.no_grad():
            for iter, data in enumerate(tqdm(test_generator)):
                if torch.cuda.is_available():
                    anchors = anchors_finder(data['image'].cuda().float())
                    classification, regression = model(
                        data['image'].cuda().float())
                    cls_loss, reg_loss = loss_func(classification, regression,
                                                   anchors,
                                                   data['annots'].cuda())
                else:
                    anchors = anchors_finder(data['image'].float())
                    classification, regression = model(data['image'].float())
                    cls_loss, reg_loss = loss_func(classification, regression,
                                                   anchors, data['annots'])

                cls_loss = cls_loss.sum()
                reg_loss = reg_loss.sum()

                loss_classification_ls.append(float(cls_loss))
                loss_regression_ls.append(float(reg_loss))

        cls_loss = np.sum(loss_classification_ls) / test_dataset.__len__()
        reg_loss = np.sum(loss_regression_ls) / test_dataset.__len__()
        loss = (reg_loss + cls_loss) / 2

        writer.add_scalars('Total_loss', {'test': loss}, 0)
        writer.add_scalars('Regression_loss', {'test': reg_loss}, 0)
        writer.add_scalars('Classification_loss (focal loss)',
                           {'test': cls_loss}, 0)

        print(
            'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
            .format(0, opt.num_epochs, cls_loss, reg_loss, np.mean(loss)))

        mAP_1, _ = evaluate(test_generator,
                            model,
                            iou_threshold=0.5,
                            score_threshold=0.5)
        mAP_5, _ = evaluate(test_generator,
                            model,
                            iou_threshold=0.75,
                            score_threshold=0.1)

        writer.add_scalars(
            'mAP', {
                'score threshold 0.5; iou threshold {}'.format(0.5): mAP_1,
            }, 0)
        writer.add_scalars(
            'mAP', {
                'score threshold 0.1 ; iou threshold {}'.format(0.75): mAP_5,
            }, 0)

        mAP_1_grey, _ = evaluate(test_grey_generator,
                                 model,
                                 iou_threshold=0.5,
                                 score_threshold=0.5)
        mAP_5_grey, _ = evaluate(test_grey_generator,
                                 model,
                                 iou_threshold=0.75,
                                 score_threshold=0.1)

        writer.add_scalars(
            'mAP', {
                'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                mAP_1_grey,
            }, 0)
        writer.add_scalars(
            'mAP', {
                'grey: True; score threshold 0.1 ; iou threshold {}'.format(0.75):
                mAP_5_grey,
            }, 0)

    model.train()

    num_iter_per_epoch = len(train_generator)
    train_iter = 0
    best_eval_loss = 10.0
    for epoch in range(opt.num_epochs):
        epoch_loss = []
        bn_update_data_list = []
        progress_bar = tqdm(train_generator)
        for iter, data in enumerate(progress_bar):
            scheduler.step(epoch + iter / train_generator.__len__())
            optimizer.zero_grad()

            if torch.cuda.is_available():
                if iter == 0:
                    # keep the sample on the GPU so bn_update later runs on the model's device
                    bn_update_data_list.append(data['image'].cuda().float())
                anchors = anchors_finder(data['image'].cuda().float())
                classification, regression = model(
                    data['image'].cuda().float())
                cls_loss, reg_loss = loss_func(classification, regression,
                                               anchors, data['annots'].cuda())
            else:
                if iter == 0:
                    bn_update_data_list.append(data['image'].float())
                anchors = anchors_finder(data['image'].float())
                classification, regression = model(data['image'].float())
                cls_loss, reg_loss = loss_func(classification, regression,
                                               anchors, data['annots'])

            cls_loss = cls_loss.mean()
            reg_loss = reg_loss.mean()
            loss = (reg_loss + cls_loss) / 2

            if loss == 0:
                continue

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           opt.glip_threshold)
            optimizer.step()
            epoch_loss.append(float(loss))
            total_loss = np.mean(epoch_loss)
            train_iter += 1

            progress_bar.set_description(
                'Epoch: {}/{}. Iteration: {}/{}. Cls loss: {:.5f}. Reg loss: {:.5f}. loss: {:.5f} Total loss: {:.5f}'
                .format(epoch + 1, opt.num_epochs, iter + 1,
                        num_iter_per_epoch, cls_loss, reg_loss, loss,
                        total_loss))
            writer.add_scalars('Total_loss', {'train': total_loss}, train_iter)
            writer.add_scalars('Regression_loss', {'train': reg_loss},
                               train_iter)
            writer.add_scalars('Classification_loss (focal loss)',
                               {'train': cls_loss}, train_iter)

        if (epoch + 1) % opt.test_interval == 0 and epoch + 1 >= 0:

            loss_regression_ls = []
            loss_classification_ls = []
            optimizer.swap_swa_sgd()
            optimizer.bn_update(bn_update_data_list, model)
            model.eval()

            with torch.no_grad():
                for iter, data in enumerate(tqdm(test_generator)):

                    if torch.cuda.is_available():
                        anchors = anchors_finder(data['image'].cuda().float())
                        classification, regression = model(
                            data['image'].cuda().float())
                        cls_loss, reg_loss = loss_func(classification,
                                                       regression, anchors,
                                                       data['annots'].cuda())

                    else:
                        anchors = anchors_finder(data['image'].float())
                        classification, regression = model(
                            data['image'].float())
                        cls_loss, reg_loss = loss_func(classification,
                                                       regression, anchors,
                                                       data['annots'])

                    cls_loss = cls_loss.sum()
                    reg_loss = reg_loss.sum()

                    loss_classification_ls.append(float(cls_loss))
                    loss_regression_ls.append(float(reg_loss))

            cls_loss = np.sum(loss_classification_ls) / test_dataset.__len__()
            reg_loss = np.sum(loss_regression_ls) / test_dataset.__len__()
            loss = (reg_loss + cls_loss) / 2

            writer.add_scalars('Total_loss', {'test': loss}, train_iter)
            writer.add_scalars('Regression_loss', {'test': reg_loss},
                               train_iter)
            writer.add_scalars('Classification_loss (focal loss)',
                               {'test': cls_loss}, train_iter)

            print(
                'Epoch: {}/{}. Classification loss: {:1.5f}. Regression loss: {:1.5f}. Total loss: {:1.5f}'
                .format(epoch + 1, opt.num_epochs, cls_loss, reg_loss,
                        np.mean(loss)))

            if 0 < loss < best_eval_loss and not (epoch +
                                                  1) % opt.eval_interval == 0:
                best_eval_loss = loss

                mAP_1, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.5,
                                    score_threshold=0.5)
                mAP_5, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.75,
                                    score_threshold=0.1)

                writer.add_scalars('mAP', {
                    'score threshold 0.5; iou threshold {}'.format(0.5):
                    mAP_1,
                }, train_iter)
                writer.add_scalars('mAP', {
                    'score threshold 0.1 ; iou threshold {}'.format(0.75):
                    mAP_5,
                }, train_iter)

                mAP_1_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.5,
                                         score_threshold=0.5)
                mAP_5_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.75,
                                         score_threshold=0.1)

                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                        mAP_1_grey,
                    }, train_iter)
                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.1 ; iou threshold {}'.format(0.75):
                        mAP_5_grey,
                    }, train_iter)

                if torch.cuda.device_count() > 1:
                    checkpoint_dict = {
                        'epoch': epoch + 1,
                        'state_dict': model.module.state_dict()
                    }
                    torch.save(
                        checkpoint_dict,
                        os.path.join(
                            opt.checkpoint_root_dir,
                            '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                                'model_v1', epoch + 1, loss)))
                else:
                    checkpoint_dict = {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict()
                    }
                    torch.save(
                        checkpoint_dict,
                        os.path.join(
                            opt.checkpoint_root_dir,
                            '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                                'model_v1', epoch + 1, loss)))

            if (epoch + 1) % opt.eval_interval == 0 and epoch + 1 >= 0:
                mAP_1, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.5,
                                    score_threshold=0.5)
                mAP_5, _ = evaluate(test_generator,
                                    model,
                                    iou_threshold=0.75,
                                    score_threshold=0.1)

                writer.add_scalars('mAP', {
                    'score threshold 0.5; iou threshold {}'.format(0.5):
                    mAP_1,
                }, train_iter)
                writer.add_scalars('mAP', {
                    'score threshold 0.1 ; iou threshold {}'.format(0.75):
                    mAP_5,
                }, train_iter)

                mAP_1_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.5,
                                         score_threshold=0.5)
                mAP_5_grey, _ = evaluate(test_grey_generator,
                                         model,
                                         iou_threshold=0.75,
                                         score_threshold=0.1)

                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.5; iou threshold {}'.format(0.5):
                        mAP_1_grey,
                    }, train_iter)
                writer.add_scalars(
                    'mAP', {
                        'grey: True; score threshold 0.1 ; iou threshold {}'.format(0.75):
                        mAP_5_grey,
                    }, train_iter)

                if torch.cuda.device_count() > 1:
                    checkpoint_dict = {
                        'epoch': epoch + 1,
                        'state_dict': model.module.state_dict()
                    }
                    torch.save(
                        checkpoint_dict,
                        os.path.join(
                            opt.checkpoint_root_dir,
                            '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                                'model_v1', epoch + 1, loss)))
                else:
                    checkpoint_dict = {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict()
                    }
                    torch.save(
                        checkpoint_dict,
                        os.path.join(
                            opt.checkpoint_root_dir,
                            '{}_checkpoint_epoch_{}_loss_{}.pth'.format(
                                'model_v1', epoch + 1, loss)))

            optimizer.swap_swa_sgd()

            model.train()

        scheduler.step()

    writer.close()
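Before each evaluation the loop above calls swap_swa_sgd() to load the averaged weights, bn_update() to recompute the BatchNorm running statistics for them, and swap_swa_sgd() again afterwards so training resumes from the plain weights. A small self-contained sketch of that sequence with torchcontrib's automatic mode; the toy model, loader, and swa_start/swa_freq values are illustrative assumptions:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 1))
optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3),
                swa_start=2, swa_freq=2)  # automatic mode, counted in optimizer steps
loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=16)

for epoch in range(3):
    model.train()
    for x, y in loader:
        optimizer.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        optimizer.step()

    optimizer.swap_swa_sgd()            # evaluate with the averaged weights
    optimizer.bn_update(loader, model)  # recompute BatchNorm statistics for the averaged weights
    model.eval()
    with torch.no_grad():
        val_loss = sum(nn.functional.mse_loss(model(x), y).item() for x, y in loader)
    optimizer.swap_swa_sgd()            # swap the plain weights back in and keep training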
Exemplo n.º 15
0
	global_step = 0
	for epoch in trange(config['number_epochs']):
		model.train()
		train_bar = tqdm(train_loader)
		train_bar.set_description_str(desc=f"N epochs - {epoch}")

		for step, batch in enumerate(train_bar):
			global_step += 1
			image = batch['image'].to(device)
			label4class = batch['label0'].to(device)
			label = batch['label1'].to(device)

			output4class, output = model(image)
			loss4class = criterion4class(output4class, label4class)
			loss = criterion(output.squeeze(), label)
			swa.zero_grad()
			total_loss = loss4class*0.5 + loss*0.5
			total_loss.backward()
			swa.step()
			train_writer.add_scalar(tag="learning_rate", scalar_value=scheduler.get_lr()[0], global_step=global_step)
			train_writer.add_scalar(tag="BinaryLoss", scalar_value=loss.item(), global_step=global_step)
			train_writer.add_scalar(tag="SoftMaxLoss", scalar_value=loss4class.item(), global_step=global_step)
			train_bar.set_postfix_str(f"Loss = {loss.item()}")
			try:
				train_writer.add_scalar(tag="idrnd_score", scalar_value=idrnd_score_pytorch(label, output), global_step=global_step)
				train_writer.add_scalar(tag="far_score", scalar_value=far_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="frr_score", scalar_value=frr_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="accuracy", scalar_value=bce_accuracy(label, output), global_step=global_step)
			except Exception:
				pass
def train_model(model,criterion, optimizer, lr_scheduler,arc_model=None):

    train_dataset = gaodeDataset(opt.trainValConcat_dir, opt.train_list, phase='train', input_size=opt.input_size)
    trainloader = DataLoader(train_dataset,
                             batch_size=opt.train_batch_size,
                             shuffle=True,
                             num_workers=opt.num_workers)

    total_iters=len(trainloader)
    logger.info('total_iters:{}'.format(total_iters))
    model_name=opt.backbone
    train_loss = []
    since = time.time()
    best_model_wts = model.state_dict()
    best_score = 0.0
    model.train(True)
    logger.info('start training...')
    #
    if lr_scheduler is cos_lr_scheduler:
        return_lr_scheduler = lr_scheduler(optimizer, 4)
    if lr_scheduler is exp_lr_scheduler:
        return_lr_scheduler = lr_scheduler(optimizer)
    #
    optimizer=SWA(optimizer,swa_start=10, swa_freq=5)
    for epoch in range(1,opt.max_epoch+1):
        begin_time=time.time()
        #logger.info('learning rate:{}'.format(optimizer.param_groups[-1]['lr']))
        logger.info('Epoch {}/{}'.format(epoch, opt.max_epoch))
        logger.info('-' * 10)
        #optimizer = lr_scheduler(optimizer, epoch)
        running_loss = 0.0
        running_corrects_linear = 0
        running_corrects_arc=0
        count=0
        iters=len(trainloader)
        for i, data in enumerate(trainloader):
            count+=1
            inputs, labels = data
            labels = labels.type(torch.LongTensor)
            inputs, labels = inputs.cuda(), labels.cuda()
            #
            out_linear= model(inputs)
            _, linear_preds = torch.max(out_linear.data, 1)
            loss = criterion(out_linear, labels)
            #
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if lr_scheduler is cos_lr_scheduler:
                return_lr_scheduler.step(epoch + i / iters)

            if i % opt.print_interval == 0 or out_linear.size()[0] < opt.train_batch_size:
                spend_time = time.time() - begin_time
                logger.info(
                    ' Epoch:{}({}/{}) loss:{:.3f} lr:{:.7f} epoch_Time:{}min:'.format(
                        epoch, count, total_iters,
                        loss.item(),optimizer.param_groups[-1]['lr'],
                        spend_time / count * total_iters // 60 - spend_time // 60))
                train_loss.append(loss.item())
            running_corrects_linear += torch.sum(linear_preds == labels.data)
            #
        if lr_scheduler is exp_lr_scheduler:
            return_lr_scheduler.step()
        weight_score = val_model(model, criterion)
        epoch_acc_linear = running_corrects_linear.double() / total_iters / opt.train_batch_size
        logger.info('Epoch:[{}/{}] train_acc={:.3f} '.format(epoch, opt.max_epoch,
                                                                    epoch_acc_linear))
        save_dir = os.path.join(opt.checkpoints_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        model_out_path = save_dir + "/" + '{}_'.format(model_name) + str(epoch) + '.pth'
        best_model_out_path = save_dir + "/" + '{}_'.format(model_name) + 'best' + '.pth'
        #save the best model
        if weight_score > best_score:
            best_score = weight_score
            torch.save(model.state_dict(), best_model_out_path)
        #save based on epoch interval
        if epoch % opt.save_interval == 0 and epoch>opt.min_save_epoch:
            torch.save(model.state_dict(), model_out_path)
        #
    #swap to the SWA averaged weights once training is finished
    optimizer.swap_swa_sgd()
    #
    logger.info('Best WeightF1: {:.3f}'.format(best_score))
    time_elapsed = time.time() - since
    logger.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
Exemplo n.º 17
0
    def train_model(self, model, criterion, lr_scheduler):

        self.logger.info('Using: {}'.format(self.model_name))
        self.logger.info('Using the GPU: {}'.format(self.gpu_id))
        self.logger.info('start training...')
        train_loss = []
        since = time.time()
        best_acc = 0.0
        best_model_wts = model.state_dict()
        model.train(True)
        base_optimizer = optim.SGD((model.parameters()),
                                   lr=self.lr,
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay)
        optimizer = SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=0.001)
        # cosine annealing schedule

        # if lr_scheduler is cos_lr_scheduler:
        #     return_lr_scheduler = lr_scheduler(optimizer, 5)
        # if lr_scheduler is exp_lr_scheduler:
        #     return_lr_scheduler = lr_scheduler(optimizer)

        for epoch in range(self.num_epochs):

            begin_time = time.time()
            data_loaders, dset_sizes = self.loaddata(
                train_dir=self.train_dir,
                batch_size=self.train_batch_size,
                shuffle=True,
                is_train=True)
            self.logger.info('-' * 10)
            self.logger.info('Epoch {}/{}'.format(epoch, self.num_epochs - 1))
            self.logger.info('learning rate:{}'.format(
                optimizer.param_groups[-1]['lr']))
            self.logger.info('-' * 10)
            running_loss = 0.0
            running_corrects = 0
            count = 0
            for i, data in enumerate(data_loaders):
                count += 1
                inputs, labels = data
                labels = labels.type(torch.LongTensor)
                inputs, labels = inputs.cuda(), labels.cuda()
                optimizer.zero_grad()
                _, outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs.data, 1)
                loss.backward()
                optimizer.step()
                # if epoch > 5 and (epoch-1) % 5 == 0:
                #     optimizer.update_swa()
                if i % self.print_interval == 0 or outputs.size(
                )[0] < self.train_batch_size:
                    spend_time = time.time() - begin_time
                    self.logger.info(' Epoch:{}({}/{}) loss:{:.3f} '.format(
                        epoch, count, dset_sizes // self.train_batch_size,
                        loss.item()))
                    train_loss.append(loss.item())
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

                # if lr_scheduler is exp_lr_scheduler:
                #     return_lr_scheduler.step()
                # if lr_scheduler is cos_lr_scheduler:
                #     return_lr_scheduler.step()

            val_acc = self.test_model(model, criterion)
            # self.train_infer(model, epoch)
            epoch_loss = running_loss / dset_sizes
            epoch_acc = running_corrects.double() / dset_sizes

            # self.logger.info('Epoch:[{}/{}]\t Loss={:.5f}\t Acc={:.3f} epoch_Time:{} min:'.format(epoch , self.num_epochs-1, epoch_loss, epoch_acc, spend_time/60))
            self.logger.info(
                'Epoch:[{}/{}] Loss={:.5f}  Acc={:.3f} Epoch_Time:{} min: ETA: {} hours'
                .format(epoch, self.num_epochs - 1, epoch_loss, epoch_acc,
                        spend_time / 60,
                        (self.num_epochs - epoch) * spend_time / 3600))
            if val_acc > best_acc and epoch > self.min_save_epoch:
                best_acc = val_acc
                best_model_wts = model.state_dict()
            if val_acc > 0.999:
                break
            save_dir = os.path.join(self.out_dir, self.model_name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            model_out_path = save_dir + "/" + '{}_'.format(
                self.model_name) + str(epoch) + '.pth'
            torch.save(model.module.state_dict(), model_out_path)
        # save best model
        self.logger.info('Best Accuracy: {:.3f}'.format(best_acc))
        model.load_state_dict(best_model_wts)
        model_out_path = save_dir + "/" + '{}_best.pth'.format(self.model_name)
        torch.save(model, model_out_path)
        time_elapsed = time.time() - since
        self.logger.info('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
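The trainer above wraps SGD in SWA(base_optimizer, swa_start=10, swa_freq=5, swa_lr=0.001), so averaging is handled automatically once training passes swa_start optimizer steps; note that the averaged weights only reach the model when swap_swa_sgd() is called. A minimal, self-contained sketch of that automatic mode; the toy model, data, and step count are illustrative assumptions:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(4, 1)
optimizer = SWA(torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9),
                swa_start=10, swa_freq=5, swa_lr=0.001)  # after 10 steps, average every 5th step at lr=0.001
x, y = torch.randn(64, 4), torch.randn(64, 1)

for step in range(30):
    optimizer.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()  # SWA folds the weights into its running average on the scheduled steps

optimizer.swap_swa_sgd()  # put the averaged weights into the model before evaluating or saving
with torch.no_grad():
    swa_loss = nn.functional.mse_loss(model(x), y).item()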
Exemplo n.º 18
0
def fit(
    model,
    train_dataset,
    val_dataset,
    optimizer_name="adam",
    samples_per_player=0,
    epochs=50,
    batch_size=32,
    val_bs=32,
    warmup_prop=0.1,
    lr=1e-3,
    acc_steps=1,
    swa_first_epoch=50,
    num_classes_aux=0,
    aux_mode="sigmoid",
    verbose=1,
    first_epoch_eval=0,
    device="cuda",
):
    """
    Fitting function for the classification task.

    Args:
        model (torch model): Model to train.
        train_dataset (torch dataset): Dataset to train with.
        val_dataset (torch dataset): Dataset to validate with.
        optimizer_name (str, optional): Optimizer name. Defaults to 'adam'.
        samples_per_player (int, optional): Number of images to use per player. Defaults to 0.
        epochs (int, optional): Number of epochs. Defaults to 50.
        batch_size (int, optional): Training batch size. Defaults to 32.
        val_bs (int, optional): Validation batch size. Defaults to 32.
        warmup_prop (float, optional): Warmup proportion. Defaults to 0.1.
        lr (float, optional): Learning rate. Defaults to 1e-3.
        acc_steps (int, optional): Accumulation steps. Defaults to 1.
        swa_first_epoch (int, optional): Epoch to start applying SWA from. Defaults to 50.
        num_classes_aux (int, optional): Number of auxiliary classes. Defaults to 0.
        aux_mode (str, optional): Mode for auxiliary classification. Defaults to 'sigmoid'.
        verbose (int, optional): Period (in epochs) to display logs at. Defaults to 1.
        first_epoch_eval (int, optional): Epoch to start evaluating at. Defaults to 0.
        device (str, optional): Device for torch. Defaults to "cuda".

    Returns:
        numpy array [len(val_dataset)]: Last predictions on the validation data.
        numpy array [len(val_dataset) x num_classes_aux]: Last aux predictions on the val data.
    """

    optimizer = define_optimizer(optimizer_name, model.parameters(), lr=lr)

    if swa_first_epoch <= epochs:
        optimizer = SWA(optimizer)

    loss_fct = nn.BCEWithLogitsLoss()
    loss_fct_aux = nn.BCEWithLogitsLoss(
    ) if aux_mode == "sigmoid" else nn.CrossEntropyLoss()
    aux_loss_weight = 1 if num_classes_aux else 0

    if samples_per_player:
        sampler = PlayerSampler(
            RandomSampler(train_dataset),
            train_dataset.players,
            batch_size=batch_size,
            drop_last=True,
            samples_per_player=samples_per_player,
        )
        train_loader = DataLoader(
            train_dataset,
            batch_sampler=sampler,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

        print(
            f"Using {len(train_loader)} out of {len(train_dataset) // batch_size} "
            f"batches by limiting to {samples_per_player} samples per player.\n"
        )
    else:
        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=NUM_WORKERS,
            pin_memory=True,
        )

    val_loader = DataLoader(
        val_dataset,
        batch_size=val_bs,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )

    num_training_steps = int(epochs * len(train_loader))
    num_warmup_steps = int(warmup_prop * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps,
                                                num_training_steps)

    for epoch in range(epochs):
        model.train()

        start_time = time.time()
        optimizer.zero_grad()

        avg_loss = 0

        if epoch + 1 > swa_first_epoch:
            optimizer.swap_swa_sgd()

        for batch in train_loader:
            images = batch[0].to(device)
            y_batch = batch[1].to(device).view(-1).float()
            y_batch_aux = batch[2].to(device).float()
            y_batch_aux = y_batch_aux.float(
            ) if aux_mode == "sigmoid" else y_batch_aux.long()

            y_pred, y_pred_aux = model(images)

            loss = loss_fct(y_pred.view(-1), y_batch)
            if aux_loss_weight:
                loss += aux_loss_weight * loss_fct_aux(y_pred_aux, y_batch_aux)
            loss.backward()

            avg_loss += loss.item() / len(train_loader)
            optimizer.step()
            scheduler.step()
            for param in model.parameters():
                param.grad = None

        if epoch + 1 >= swa_first_epoch:
            optimizer.update_swa()
            optimizer.swap_swa_sgd()

        preds = np.empty(0)
        preds_aux = np.empty((0, num_classes_aux))
        model.eval()
        avg_val_loss, auc, scores_aux = 0., 0., 0.
        if epoch + 1 >= first_epoch_eval or epoch + 1 == epochs:
            with torch.no_grad():
                for batch in val_loader:
                    images = batch[0].to(device)
                    y_batch = batch[1].to(device).view(-1).float()
                    y_aux = batch[2].to(device).float()
                    y_batch_aux = y_aux.float(
                    ) if aux_mode == "sigmoid" else y_aux.long()

                    y_pred, y_pred_aux = model(images)

                    loss = loss_fct(y_pred.detach().view(-1), y_batch)
                    if aux_loss_weight:
                        loss += aux_loss_weight * loss_fct_aux(
                            y_pred_aux.detach(), y_batch_aux)

                    avg_val_loss += loss.item() / len(val_loader)

                    y_pred = torch.sigmoid(y_pred).view(-1)
                    preds = np.concatenate(
                        [preds, y_pred.detach().cpu().numpy()])

                    if num_classes_aux:
                        y_pred_aux = (y_pred_aux.sigmoid() if aux_mode
                                      == "sigmoid" else y_pred_aux.softmax(-1))
                        preds_aux = np.concatenate(
                            [preds_aux,
                             y_pred_aux.detach().cpu().numpy()])

            auc = roc_auc_score(val_dataset.labels, preds)

            if num_classes_aux:
                if aux_mode == "sigmoid":
                    scores_aux = np.round(
                        [
                            roc_auc_score(val_dataset.aux_labels[:, i],
                                          preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
                else:
                    scores_aux = np.round(
                        [
                            roc_auc_score((val_dataset.aux_labels
                                           == i).astype(int), preds_aux[:, i])
                            for i in range(num_classes_aux)
                        ],
                        3,
                    ).tolist()
            else:
                scores_aux = 0

        elapsed_time = time.time() - start_time
        if (epoch + 1) % verbose == 0:
            elapsed_time = elapsed_time * verbose
            lr = scheduler.get_last_lr()[0]
            print(
                f"Epoch {epoch + 1:02d}/{epochs:02d} \t lr={lr:.1e}\t t={elapsed_time:.0f}s \t"
                f"loss={avg_loss:.3f}",
                end="\t",
            )

            if epoch + 1 >= first_epoch_eval:
                print(
                    f"val_loss={avg_val_loss:.3f} \t auc={auc:.3f}\t aucs_aux={scores_aux}"
                )
            else:
                print("")

    del val_loader, train_loader, y_pred
    torch.cuda.empty_cache()

    return preds, preds_aux
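fit() drives SWA manually at the epoch level: once the SWA phase has started it swaps the plain weights back in before training, and after each epoch it folds them into the average and swaps the averaged weights in for validation. A self-contained toy with the same bookkeeping; the model, data, and swa_first_epoch value are illustrative assumptions:

import torch
import torch.nn as nn
from torchcontrib.optim import SWA

model = nn.Linear(4, 1)
optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))  # manual mode
x, y = torch.randn(32, 4), torch.randn(32, 1)
epochs, swa_first_epoch = 6, 3

for epoch in range(epochs):
    if epoch + 1 > swa_first_epoch:
        optimizer.swap_swa_sgd()   # the model holds averaged weights from last epoch: swap back

    optimizer.zero_grad()
    nn.functional.mse_loss(model(x), y).backward()
    optimizer.step()

    if epoch + 1 >= swa_first_epoch:
        optimizer.update_swa()     # accumulate this epoch's weights into the average
        optimizer.swap_swa_sgd()   # validate (and finish) with the averaged weights in the model
        with torch.no_grad():
            val_loss = nn.functional.mse_loss(model(x), y).item()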
Exemplo n.º 19
0
def train(model_name, optim='adam'):
	train_dataset = IDRND_dataset_CV(fold=fold, mode=config['mode'],
									 add_idrnd_v1_dataset=True,
									 add_NUAA=False, aug=[0.5, 0.75, 0.25],
									 double_loss_mode=True, output_shape=config['image_resolution'])
	train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=8,
							  pin_memory=True, drop_last=True)

	val_dataset = IDRND_dataset_CV(fold=fold, mode=config['mode'].replace('train', 'val'),
								   double_loss_mode=True, output_shape=config['image_resolution'])
	val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=4, drop_last=False)

	if model_name == 'EF':
		model = DoubleLossModelTwoHead(base_model=EfficientNet.from_pretrained('efficientnet-b3')).to(device)
		model.load_state_dict(torch.load(f"../models_weights/pretrained/{model_name}_{8}_1.5062978111598622_0.9967353313006619.pth"))
	elif model_name == 'EFGAP':
		model = DoubleLossModelTwoHead(base_model=EfficientNetGAP.from_pretrained('efficientnet-b3')).to(device)
		model.load_state_dict(torch.load(f"../models_weights/pretrained/{model_name}_{8}_1.6058124488733547_1.0.pth"))

	criterion = FocalLoss(add_weight=False).to(device)
	criterion4class = CrossEntropyLoss().to(device)

	binary_weight = 0.5
	softmax_weight = 0.5

	if optim == 'adam':
		optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
	elif optim == 'sgd':
		optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], weight_decay=config['weight_decay'])
	else:
		optimizer = torch.optim.SGD(model.parameters(), momentum=0.9, lr=config['learning_rate'], weight_decay=config['weight_decay'], nesterov=True)

	steps_per_epoch = train_loader.__len__()
	swa = SWA(optimizer, swa_start=config['swa_start'] * steps_per_epoch,
			  swa_freq=int(config['swa_freq'] * steps_per_epoch), swa_lr=config['learning_rate'] / 10)
	# scheduler = ExponentialLR(swa, gamma=0.9)
	scheduler = StepLR(swa, step_size=5*steps_per_epoch, gamma=0.5)

	global_step = 0
	for epoch in trange(config['number_epochs']):
		if epoch == 3:
			train_dataset = IDRND_dataset_CV(fold=fold, mode=config['mode'],
											 add_idrnd_v1_dataset=True,
											 add_NUAA=False,
											 double_loss_mode=True,
											 output_shape=config['image_resolution'],
											 aug=[0.0, 0.5, 0.0])
			train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=8,
									  pin_memory=True, drop_last=True)
		model.train()
		train_bar = tqdm(train_loader)
		train_bar.set_description_str(desc=f"N epochs - {epoch}")

		for step, batch in enumerate(train_bar):
			global_step += 1
			image = batch['image'].to(device)
			label4class = batch['label0'].to(device)
			label = batch['label1'].to(device)

			output4class, output = model(image)
			loss4class = criterion4class(output4class, label4class)
			loss = criterion(output.squeeze(), label)
			swa.zero_grad()
			total_loss = loss4class * softmax_weight + loss * binary_weight
			total_loss.backward()
			swa.step()
			train_writer.add_scalar(tag="learning_rate", scalar_value=scheduler.get_lr()[0], global_step=global_step)
			train_writer.add_scalar(tag="BinaryLoss", scalar_value=loss.item(), global_step=global_step)
			train_writer.add_scalar(tag="SoftMaxLoss", scalar_value=loss4class.item(), global_step=global_step)
			train_bar.set_postfix_str(f"Loss = {loss.item()}")
			try:
				train_writer.add_scalar(tag="idrnd_score", scalar_value=idrnd_score_pytorch(label, output),
										global_step=global_step)
				train_writer.add_scalar(tag="far_score", scalar_value=far_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="frr_score", scalar_value=frr_score(label, output), global_step=global_step)
				train_writer.add_scalar(tag="accuracy", scalar_value=bce_accuracy(label, output),
										global_step=global_step)
			except Exception:
				pass

		softmax_weight *= 0.92
		binary_weight *= 1.02

		if (epoch > config['swa_start'] and epoch % 2 == 0) or (epoch == config['number_epochs'] - 1):
			swa.swap_swa_sgd()
			swa.bn_update(train_loader, model, device)
			swa.swap_swa_sgd()

		scheduler.step()
		evaluate(model, val_loader, epoch, model_name)
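Because torchcontrib's SWA wrapper is itself a torch.optim.Optimizer, the code above passes it directly to an LR scheduler and expresses swa_start/swa_freq in optimizer steps rather than epochs. A compact sketch of that arrangement; the toy model, data, and scaled-down step counts are illustrative assumptions:

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from torchcontrib.optim import SWA

model = nn.Linear(4, 1)
steps_per_epoch = 10
swa = SWA(torch.optim.Adam(model.parameters(), lr=1e-3),
          swa_start=2 * steps_per_epoch,  # start averaging after two epochs' worth of steps
          swa_freq=steps_per_epoch,       # then average once per epoch
          swa_lr=1e-4)
scheduler = StepLR(swa, step_size=2, gamma=0.5)  # the SWA wrapper is a valid optimizer for schedulers
x, y = torch.randn(32, 4), torch.randn(32, 1)

for epoch in range(4):
    for _ in range(steps_per_epoch):
        swa.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        swa.step()
    scheduler.step()  # per-epoch LR decay applied to the wrapped optimizer's param groups

swa.swap_swa_sgd()    # finish with the averaged weights in the model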
Exemplo n.º 20
0
class Trainer(object):
    def __init__(self,
                 args,
                 surrogate,
                 train_data,
                 val_data,
                 tflogger=None,
                 pde=None):
        self.args = args
        self.pde = pde
        self.surrogate = surrogate
        self.tflogger = tflogger
        self.train_data = train_data
        self.val_data = val_data
        # self.init_transformations()
        self.train_loader = DataLoader(self.train_data,
                                       batch_size=args.batch_size,
                                       shuffle=True,
                                       pin_memory=True)
        self.val_loader = DataLoader(self.val_data,
                                     batch_size=args.batch_size,
                                     shuffle=False,
                                     pin_memory=True)
        self.init_optimizer()
        self.train_f_std = _cuda(torch.Tensor([[1.0]]))
        self.train_J_std = _cuda(torch.Tensor([[1.0]]))

    def init_transformations(self):
        # Init transformations corresponding to rotations, flips

        d = self.surrogate.fsm.vector_dim

        v2r = self.surrogate.fsm.vec_to_ring_map.cpu()

        v2r_perm = torch.argmax(v2r, dim=1).cpu().numpy()
        r2v_perm = torch.argmax(v2r, dim=0).cpu().numpy()

        # rotated = R * original
        p1 = np.array([i for i in range(d)])
        p2 = np.mod(p1 + int(d / 4), d)  # rotate 90
        p3 = np.mod(p2 + int(d / 4), d)  # rotate 180
        p4 = np.mod(p3 + int(d / 4), d)  # rotate 270

        # flip about 0th loc which is two points in 2d
        # TODO: make this dimension-agnostic
        def flip(p):
            p = p.reshape(-1, 2)
            p = np.concatenate(([p[0]], p[1:][::-1]), axis=0)
            return p.reshape(-1)

        ps = [p1, p2, p3, p4]  # , p2, p3, p4]
        # ps = ps + [flip(p) for p in ps]
        # ps = [p.tolist() for p in ps]

        self.perms = [v2r_perm[p[r2v_perm]].tolist() for p in ps]

        # trans_mats = [torch.eye(d)[p] for p in ps]
        # self.trans_mats = [torch.matmul(torch.matmul(
        #     self.surrogate.fsm.vec_to_ring_map,
        #     _cuda(m)), self.surrogate.fsm.vec_to_ring_map.t())
        #               for m in trans_mats]

        N = self.train_data.size()
        self.train_data.memory_size = self.train_data.memory_size * 8

        us_orig = torch.stack([td[0] for td in self.train_data.data])
        Js_orig = torch.stack([td[3] for td in self.train_data.data])
        Hs_orig = torch.stack([td[4] for td in self.train_data.data])

        for pn, perm in enumerate(self.perms):
            print("proc perm ", pn)
            print("perm: ", perm)
            us = us_orig[:, perm]
            Js = Js_orig[:, perm]

            Hs = Hs_orig[:, perm, :]
            Hs = Hs[:, :, perm]

            for i in range(len(us)):
                self.train_data.feed(
                    Example(
                        us[i].clone().detach(),
                        self.train_data.data[i][1].clone().detach(),
                        self.train_data.data[i][2].clone().detach(),
                        Js[i].clone().detach(),
                        Hs[i].clone().detach(),
                    ))
            """
            for n in range(N):
                if n % 1000 == 0:
                    print('preproc perm {}, n {}'.format(pn, n))
                u, p, f, J, H = self.train_data.data[n]
                u = u[perm]
                J = J[perm]
                H = H[[perm for _ in range(len(perm))],
                      [[i for i in range(len(perm))]
                      for _ in range(len(perm))]]
                H = H[[[i for i in range(len(perm))]
                       for _ in range(len(perm))],
                      [perm for _ in range(len(perm))]]

                self.train_data.feed(Example(u, p, f, J, H))
            """
            """
                ii = [[i for _ in range(self.surrogate.fsm.vector_dim)]
                        for i in range(len(u))]
                jjs = [self.perms[np.random.choice(len(self.perms))]
                        for _ in range(len(u))]

                u = u[iis, jjs]
                J = J[iis, jjs]

                iis = [[ii for _ in range(self.surrogate.fsm.vector_dim)]
                        for ii in iis]
                jjs = [[jj for _ in range(self.surrogate.fsm.vector_dim)]
                        for jj in jjs]
                kks = [[[k for k in range(self.surrogate.fsm.vector_dim)]
                        for _ in range(self.surrogate.fsm.vector_dim)]
                       for _ in range(len(u))]
                # pdb.set_trace()
                H = H[iis, jjs, kks]
                H = H[iis, kks, jjs]
            """
        # for tm in self.trans_mats:

    def init_optimizer(self):
        # Create optimizer if surrogate is trainable
        if hasattr(self.surrogate, "parameters"):
            if self.args.optimizer == "adam" or self.args.optimizer == "amsgrad":
                self.optimizer = torch.optim.AdamW(
                    (p
                     for p in self.surrogate.parameters() if p.requires_grad),
                    self.args.lr,
                    weight_decay=self.args.wd,
                    amsgrad=(self.args.optimizer == "amsgrad"),
                )
            elif self.args.optimizer == "sgd":
                self.optimizer = torch.optim.SGD(
                    (p
                     for p in self.surrogate.parameters() if p.requires_grad),
                    self.args.lr,
                    momentum=0.9,
                    weight_decay=self.args.wd,
                )
            elif self.args.optimizer == "radam":
                from ..util.radam import RAdam

                self.optimizer = RAdam(
                    (p
                     for p in self.surrogate.parameters() if p.requires_grad),
                    self.args.lr,
                    weight_decay=self.args.wd,
                    betas=ast.literal_eval(self.args.adam_betas),
                )
            else:
                raise Exception("Unknown optimizer")
            '''
            if self.args.fix_batch:
                self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                    self.optimizer,
                    patience=10
                    if self.args.fix_batch
                    else 20 * len(self.train_data) // self.args.batch_size,
                    verbose=self.args.verbose,
                    factor=1.0 / np.sqrt(np.sqrt(np.sqrt(2))),
                )
            else:
                """
                self.scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    self.optimizer,
                    gamma=0.1,
                    milestones=[1e2,5e2,2e3,1e4,1e5])
                """
            '''
            if self.args.swa:
                from torchcontrib.optim import SWA

                self.optimizer = SWA(
                    self.optimizer,
                    swa_start=self.args.swa_start * len(self.train_loader),
                    swa_freq=len(self.train_loader),
                    swa_lr=self.args.lr,
                )

            else:
                self.scheduler = torch.optim.lr_scheduler.CyclicLR(
                    self.optimizer,
                    base_lr=self.args.lr * 1e-3,
                    max_lr=3 * self.args.lr,
                    step_size_up=int(math.ceil(len(self.train_loader) / 2)),
                    step_size_down=int(math.floor(len(self.train_loader) / 2)),
                    mode="triangular",
                    scale_fn=None,
                    scale_mode="cycle",
                    cycle_momentum=False,
                    base_momentum=0.8,
                    max_momentum=0.9,
                    last_epoch=-1,
                )
        else:
            self.optimizer = None

    def cd_step(self, step, batch):
        """Do a single step of CD training. Log stats to tensorboard."""
        self.surrogate.net.train()
        if self.optimizer:
            self.optimizer.zero_grad()
        u, p, f, J, H, _ = batch
        u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)
        u_sgld = torch.autograd.Variable(u, requires_grad=True)
        lam, eps, TEMP = (
            self.args.cd_sgld_lambda,
            self.args.cd_sgld_eps,
            self.args.cd_sgld_temp,
        )
        for i in range(self.args.cd_sgld_steps):
            temp = TEMP * self.args.cd_sgld_steps / (i + 1)
            u_sgld = (u_sgld - 0.5 * lam * torch.autograd.grad(
                torch.log(self.surrogate.f(u_sgld, p) + eps).sum() / temp,
                u_sgld)[0].clamp(-0.1, +0.1) +
                      lam * _cuda(torch.randn(*u.size())))

        eplus = torch.log(self.surrogate.f(u.detach(), p) + eps).mean()

        eminus = torch.log(self.surrogate.f(u_sgld.detach(), p) + eps).mean()
        cd_loss = eplus - eminus

        (cd_loss * self.args.cd_weight).backward()
        if self.args.clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.surrogate.net.parameters(),
                                           self.args.clip_grad_norm)
        self.optimizer.step()

        if self.tflogger is not None:
            self.tflogger.log_scalar("cd_E+_mean", eplus.item(), step)
            self.tflogger.log_scalar("cd_E-_mean", eminus.item(), step)
            self.tflogger.log_scalar("cd_loss", cd_loss.item(), step)

    def train_step(self, step, batch):
        """Do a single step of Sobolev training. Log stats to tensorboard."""
        self.surrogate.net.train()
        if self.optimizer:
            self.optimizer.zero_grad()
        u, p, f, J, H, _ = batch

        with Timer() as timer:
            u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)

        if self.args.poisson:
            u = self.surrogate.fsm.to_ring(u)
            u[:, 0] = 0.0
            u = self.surrogate.fsm.to_torch(u)

        # pdb.set_trace()
        # T = torch.stack([self.trans_mats[i]
        #                  for i in np.random.choice(len(self.trans_mats), size=len(u))])
        # pdb.set_trace()
        # u, J, H = (
        #     torch.matmul(u.unsqueeze(1), T).squeeze(),
        #     torch.matmul(J.unsqueeze(1), T).squeeze(),
        #     torch.matmul(torch.matmul(T.permute(0, 2, 1), H), T)
        # )
        """
        iis = [[i for _ in range(self.surrogate.fsm.vector_dim)]
                for i in range(len(u))]
        jjs = [self.perms[np.random.choice(len(self.perms))]
                for _ in range(len(u))]

        u = u[iis, jjs]
        J = J[iis, jjs]

        iis = [[ii for _ in range(self.surrogate.fsm.vector_dim)]
                for ii in iis]
        jjs = [[jj for _ in range(self.surrogate.fsm.vector_dim)]
                for jj in jjs]
        kks = [[[k for k in range(self.surrogate.fsm.vector_dim)]
                for _ in range(self.surrogate.fsm.vector_dim)]
               for _ in range(len(u))]
        # pdb.set_trace()
        H = H[iis, jjs, kks]
        H = H[iis, kks, jjs]
        """
        # pdb.set_trace()

        if self.tflogger is not None:
            self.tflogger.log_scalar("batch_cuda_time", timer.interval, step)

        with Timer() as timer:
            if self.args.hess:
                vectors = torch.randn(*J.size()).to(J.device)
                fhat, Jhat, Hvphat = self.surrogate.f_J_Hvp(u,
                                                            p,
                                                            vectors=vectors)
                Hvp = (vectors.view(*J.size(), 1) * H).sum(dim=1)
            else:
                fhat, Jhat = self.surrogate.f_J(u, p)
                Hvphat = torch.zeros_like(Jhat)
                Hvp = torch.zeros_like(Jhat)

        if not self.args.poisson:
            fhat[f.view(-1) < 0] *= 0.0
            Jhat[f.view(-1) < 0] *= 0.0
            Hvphat[f.view(-1) < 0] *= 0.0
            Hvp[f.view(-1) < 0] *= 0.0
            J[f.view(-1) < 0] *= 0.0
            f[f.view(-1) < 0] *= 0.0
        # pdb.set_trace()
        if self.tflogger is not None:
            self.tflogger.log_scalar("batch_forward_time", timer.interval,
                                     step)

        with Timer() as timer:
            f_loss, f_pce, J_loss, J_sim, H_loss, H_sim, total_loss = self.stats(
                step, u, f, J, Hvp, fhat, Jhat, Hvphat)
        if self.tflogger is not None:
            self.tflogger.log_scalar("stats_forward_time", timer.interval,
                                     step)

        if not np.isfinite(total_loss.data.cpu().numpy().sum()):
            pdb.set_trace()

        with Timer() as timer:
            if self.optimizer:
                total_loss.backward()
                if self.args.verbose:
                    log([
                        getattr(p.grad, "data",
                                torch.Tensor([0.0
                                              ])).norm().cpu().numpy().sum()
                        for p in self.surrogate.net.parameters()
                    ])
                if self.args.clip_grad_norm is not None:
                    torch.nn.utils.clip_grad_norm_(
                        self.surrogate.net.parameters(),
                        self.args.clip_grad_norm)
                self.optimizer.step()
                if not self.args.swa:
                    self.scheduler.step()
                if self.args.verbose:
                    log("lr: {}".format(self.optimizer.param_groups[0]["lr"]))
        if self.tflogger is not None:
            self.tflogger.log_scalar("backward_time", timer.interval, step)
        # pdb.set_trace()
        return (
            f_loss.item(),
            f_pce.item(),
            J_loss.item(),
            J_sim.item(),
            H_loss.item(),
            H_sim.item(),
            total_loss.item(),
        )

    def val_step(self, step):
        """Do a single validation step. Log stats to tensorboard."""
        self.surrogate.net.eval()
        for i, batch in enumerate(self.val_loader):
            u, p, f, J, H, _ = batch
            u, p, f, J, H = _cuda(u), _cuda(p), _cuda(f), _cuda(J), _cuda(H)

            if self.args.poisson:
                u = self.surrogate.fsm.to_ring(u)
                u[:, 0] = 0.0
                u = self.surrogate.fsm.to_torch(u)

            if self.args.hess:
                vectors = torch.randn(*J.size()).to(J.device)
                fhat, Jhat, Hvphat = self.surrogate.f_J_Hvp(u,
                                                            p,
                                                            vectors=vectors)
                Hvp = (vectors.view(*J.size(), 1) * H).sum(dim=1)
            else:
                fhat, Jhat = self.surrogate.f_J(u, p)
                Hvphat = torch.zeros_like(Jhat)
                Hvp = torch.zeros_like(Jhat)

            if not self.args.poisson:
                fhat[f.view(-1) < 0] *= 0.0
                Jhat[f.view(-1) < 0] *= 0.0
                Hvphat[f.view(-1) < 0] *= 0.0
                Hvp[f.view(-1) < 0] *= 0.0
                J[f.view(-1) < 0] *= 0.0
                f[f.view(-1) < 0] *= 0.0

            u_ = torch.cat([u_, u.data], dim=0) if i > 0 else u.data
            f_ = torch.cat([f_, f.data], dim=0) if i > 0 else f.data
            J_ = torch.cat([J_, J.data], dim=0) if i > 0 else J.data
            Hvp_ = torch.cat([Hvp_, Hvp.data], dim=0) if i > 0 else Hvp.data
            fhat_ = torch.cat([fhat_, fhat.data],
                              dim=0) if i > 0 else fhat.data
            Jhat_ = torch.cat([Jhat_, Jhat.data],
                              dim=0) if i > 0 else Jhat.data
            Hvphat_ = torch.cat([Hvphat_, Hvphat.data],
                                dim=0) if i > 0 else Hvphat.data

        return list(r.item() for r in self.stats(
            step, u_, f_, J_, Hvp_, fhat_, Jhat_, Hvphat_, phase="val"))

    def visualize(self, step, batch, dataset_name):
        u, p, f, J, H, _ = batch
        u, p, f, J = u[:16], p[:16], f[:16], J[:16]

        if self.args.poisson:
            u = self.surrogate.fsm.to_ring(u)
            u[:, 0] = 0.0
            u = self.surrogate.fsm.to_torch(u).cpu()

        fhat, Jhat = self.surrogate.f_J(u, p)

        assert len(u) <= 16
        assert len(u) == len(p)
        assert len(u) == len(f)
        assert len(u) == len(J)
        assert len(u) == len(fhat)
        assert len(u) == len(Jhat)

        u, f, J, fhat, Jhat = u.cpu(), f.cpu(), J.cpu(), fhat.cpu(), Jhat.cpu()

        cuda = self.surrogate.fsm.cuda
        self.surrogate.fsm.cuda = False

        fig, axes = plt.subplots(4, 4, figsize=(16, 16))
        axes = [ax for axs in axes for ax in axs]
        RCs = self.surrogate.fsm.ring_coords.cpu().detach().numpy()
        rigid_remover = RigidRemover(self.surrogate.fsm)
        for i in range(len(u)):
            ax = axes[i]
            locs = RCs + self.surrogate.fsm.to_ring(u[i]).detach().numpy()
            plot_boundary(
                lambda x: (0, 0),
                1000,
                label="reference, f={:.3e}".format(f[i].item()),
                ax=ax,
                color="k",
            )
            plot_boundary(
                self.surrogate.fsm.get_query_fn(u[i]),
                1000,
                ax=ax,
                label="ub, fhat={:.3e}".format(fhat[i].item()),
                linestyle="-",
                color="darkorange",
            )
            plot_boundary(
                self.surrogate.fsm.get_query_fn(
                    rigid_remover(u[i].unsqueeze(0)).squeeze(0)),
                1000,
                ax=ax,
                label="rigid removed",
                linestyle="--",
                color="blue",
            )
            if J is not None and Jhat is not None:
                J_ = self.surrogate.fsm.to_ring(J[i])
                Jhat_ = self.surrogate.fsm.to_ring(Jhat[i])
                normalizer = np.mean(
                    np.nan_to_num([
                        J_.norm(dim=1).detach().numpy(),
                        Jhat_.norm(dim=1).detach().numpy(),
                    ]))
                plot_vectors(
                    locs,
                    J_.detach().numpy(),
                    ax=ax,
                    label="J",
                    color="darkgreen",
                    normalizer=normalizer,
                    scale=1.0,
                )
                plot_vectors(
                    locs,
                    Jhat_.detach().numpy(),
                    ax=ax,
                    color="darkorchid",
                    label="Jhat",
                    normalizer=normalizer,
                    scale=1.0,
                )
                plot_vectors(
                    locs,
                    (J_ - Jhat_).detach().numpy(),
                    ax=ax,
                    color="red",
                    label="residual J-Jhat",
                    normalizer=normalizer,
                    scale=1.0,
                )

            ax.legend()
        fig.canvas.draw()
        buf = io.BytesIO()
        plt.savefig(buf, format="png")
        buf.seek(0)
        plt.close()
        if self.tflogger is not None:
            self.tflogger.log_images("{} displacements".format(dataset_name),
                                     [buf], step)

        if J is not None and Jhat is not None:
            fig, axes = plt.subplots(4, 4, figsize=(16, 16))
            axes = [ax for axs in axes for ax in axs]
            for i in range(len(J)):
                ax = axes[i]
                J_ = self.surrogate.fsm.to_ring(J[i])
                Jhat_ = self.surrogate.fsm.to_ring(Jhat[i])
                normalizer = 10 * np.mean(
                    np.nan_to_num([
                        J_.norm(dim=1).detach().numpy(),
                        Jhat_.norm(dim=1).detach().numpy(),
                    ]))
                plot_boundary(lambda x: (0, 0), 1000, label="reference", ax=ax)
                plot_boundary(
                    self.surrogate.fsm.get_query_fn(J[i]),
                    100,
                    ax=ax,
                    label="true_J, f={:.3e}".format(f[i].item()),
                    linestyle="--",
                    normalizer=normalizer,
                )
                plot_boundary(
                    self.surrogate.fsm.get_query_fn(Jhat[i]),
                    100,
                    ax=ax,
                    label="surrogate_J, fhat={:.3e}".format(fhat[i].item()),
                    linestyle="-.",
                    normalizer=normalizer,
                )
                ax.legend()
            fig.canvas.draw()
            buf = io.BytesIO()
            plt.savefig(buf, format="png")
            buf.seek(0)
            plt.close()
            if self.tflogger is not None:
                self.tflogger.log_images("{} Jacobians".format(dataset_name),
                                         [buf], step)
        self.surrogate.fsm.cuda = cuda

    def stats(self, step, u, f, J, Hvp, fhat, Jhat, Hvphat, phase="train"):
        """Take ground truth and predictions. Log stats and return loss."""

        if self.args.l1_loss:
            f_loss = torch.nn.functional.l1_loss(
                self.surrogate.scaler.scale(f + EPS, u),
                self.surrogate.scaler.scale(fhat + EPS, u),
            )
        else:
            f_loss = torch.nn.functional.mse_loss(
                self.surrogate.scaler.scale(f + EPS, u),
                self.surrogate.scaler.scale(fhat + EPS, u),
            )

        if self.args.angle_magnitude:
            J_loss = (self.args.mag_weight * torch.nn.functional.mse_loss(
                torch.log(J.norm(dim=1) + EPS),
                torch.log(Jhat.norm(dim=1) + EPS)) + 1.0 - similarity(J, Jhat))
            H_loss = (self.args.mag_weight * torch.nn.functional.mse_loss(
                torch.log(Hvp.norm(dim=1) + EPS),
                torch.log(Hvphat.norm(dim=1) + EPS),
            ) + 1.0 - similarity(Hvp, Hvphat))
        else:
            J_loss = torch.nn.functional.mse_loss(J, Jhat)
            H_loss = torch.nn.functional.mse_loss(Hvp, Hvphat)

        total_loss = f_loss + self.args.J_weight * J_loss + self.args.H_weight * H_loss

        f_pce = error_percent(f, fhat)
        J_pce = error_percent(J, Jhat)
        H_pce = error_percent(Hvp, Hvphat)

        J_sim = similarity(J, Jhat)
        H_sim = similarity(Hvp, Hvphat)

        if self.tflogger is not None:
            self.tflogger.log_scalar("train_set_size",
                                     len(self.train_data.data), step)
            self.tflogger.log_scalar("val_set_size", len(self.val_data.data),
                                     step)

            if hasattr(self.surrogate, "net") and hasattr(
                    self.surrogate.net, "parameters"):
                self.tflogger.log_scalar(
                    "param_norm_sum",
                    sum([
                        p.norm().sum().item()
                        for p in self.surrogate.net.parameters()
                    ]),
                    step,
                )
            self.tflogger.log_scalar("total_loss_" + phase, total_loss.item(),
                                     step)
            self.tflogger.log_scalar("f_loss_" + phase, f_loss.item(), step)
            self.tflogger.log_scalar("f_pce_" + phase, f_pce.item(), step)

            self.tflogger.log_scalar("f_mean_" + phase, f.mean().item(), step)
            self.tflogger.log_scalar("f_std_" + phase, f.std().item(), step)
            self.tflogger.log_scalar("fhat_mean_" + phase,
                                     fhat.mean().item(), step)
            self.tflogger.log_scalar("fhat_std_" + phase,
                                     fhat.std().item(), step)

            self.tflogger.log_scalar("J_loss_" + phase, J_loss.item(), step)
            self.tflogger.log_scalar("J_pce_" + phase, J_pce.item(), step)
            self.tflogger.log_scalar("J_sim_" + phase, J_sim.item(), step)

            self.tflogger.log_scalar("H_loss_" + phase, H_loss.item(), step)
            self.tflogger.log_scalar("H_pce_" + phase, H_pce.item(), step)
            self.tflogger.log_scalar("H_sim_" + phase, H_sim.item(), step)

            self.tflogger.log_scalar("J_mean_" + phase, J.mean().item(), step)
            self.tflogger.log_scalar("J_std_mean_" + phase,
                                     J.std(dim=1).mean().item(), step)
            self.tflogger.log_scalar("Jhat_mean_" + phase,
                                     Jhat.mean().item(), step)
            self.tflogger.log_scalar("Jhat_std_mean_" + phase,
                                     Jhat.std(dim=1).mean().item(), step)

        return (f_loss, f_pce, J_loss, J_sim, H_loss, H_sim, total_loss)
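The stats() method above calls two helpers, error_percent and similarity, that are not defined in this excerpt. A minimal sketch consistent with how they are used (scalar results that support .item(), row-wise comparison along dim=1) might look like the following; the actual definitions in the original project may differ, and the EPS value here is an assumption rather than taken from the source.

import torch

EPS = 1e-7  # assumed value; the excerpt uses EPS but does not define it here


def error_percent(y, yhat):
    # relative L2 error of the prediction, as a single scalar tensor
    return (y - yhat).norm() / (y.norm() + EPS)


def similarity(y, yhat):
    # mean cosine similarity between corresponding rows (dim=1)
    return torch.nn.functional.cosine_similarity(y, yhat, dim=1).mean()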
Exemplo n.º 21
0
    def train(self):
        # prepare data
        train_data = self.data('train')
        train_steps = int((len(train_data) + self.config.batch_size - 1) /
                          self.config.batch_size)
        train_dataloader = DataLoader(train_data,
                                      batch_size=self.config.batch_size,
                                      collate_fn=self.get_collate_fn('train'),
                                      shuffle=True,
                                      num_workers=2)

        # prepare optimizer
        params_lr = [{
            "params": self.model.bert_parameters,
            'lr': self.config.small_lr
        }, {
            "params": self.model.other_parameters,
            'lr': self.config.large_lr
        }]
        optimizer = torch.optim.Adam(params_lr)
        optimizer = SWA(optimizer)

        # prepare early stopping
        early_stopping = EarlyStopping(self.model,
                                       self.config.best_model_path,
                                       big_server=BIG_GPU,
                                       mode='max',
                                       patience=10,
                                       verbose=True)

        # prepare learning rate schedule
        learning_schedual = LearningSchedual(
            optimizer, self.config.epochs, train_steps,
            [self.config.small_lr, self.config.large_lr])

        # prepare other
        aux = REModelAux(self.config, train_steps)
        moving_log = MovingData(window=500)

        ending_flag = False
        # self.model.load_state_dict(torch.load(ROOT_SAVED_MODEL + 'temp_model.ckpt'))
        #
        # with torch.no_grad():
        #     self.model.eval()
        #     print(self.eval())
        #     return
        for epoch in range(0, self.config.epochs):
            for step, (inputs, y_trues,
                       spo_info) in enumerate(train_dataloader):
                inputs = [t.cuda() for t in inputs]
                y_trues = [t.cuda() for t in y_trues]
                # stop detaching BERT gradients after the first 1000 steps of epoch 0
                if epoch > 0 or step == 1000:
                    self.model.detach_bert = False
                # train ================================================================================================
                preds = self.model(inputs)
                loss = self.calculate_loss(preds, y_trues, inputs[1],
                                           inputs[2])
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1)
                optimizer.step()

                with torch.no_grad():

                    logs = {'lr0': 0, 'lr1': 0}
                    # this branch is disabled by the trailing `and False`; only the
                    # moving-average loss below is logged during training
                    if (epoch > 0 or step > 620) and False:
                        sbj_f1, spo_f1 = self.calculate_train_f1(
                            spo_info[0], preds, spo_info[1:3],
                            inputs[2].cpu().numpy())
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1,
                            'sbj_correct_num': sbj_f1[0],
                            'sbj_pred_num': sbj_f1[1],
                            'sbj_true_num': sbj_f1[2],
                            'spo_correct_num': spo_f1[0],
                            'spo_pred_num': spo_f1[1],
                            'spo_true_num': spo_f1[2]
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']
                        logs['sbj_precise'], logs['sbj_recall'], logs[
                            'sbj_f1'] = calculate_f1(
                                moving_data['sbj_correct_num'],
                                moving_data['sbj_pred_num'],
                                moving_data['sbj_true_num'],
                                verbose=True)
                        logs['spo_precise'], logs['spo_recall'], logs[
                            'spo_f1'] = calculate_f1(
                                moving_data['spo_correct_num'],
                                moving_data['spo_pred_num'],
                                moving_data['spo_true_num'],
                                verbose=True)
                    else:
                        metrics_data = {
                            'loss': loss.cpu().numpy(),
                            'sampled_num': 1
                        }
                        moving_data = moving_log(epoch * train_steps + step,
                                                 metrics_data)
                        logs['loss'] = moving_data['loss'] / moving_data[
                            'sampled_num']

                    # update lr
                    logs['lr0'], logs['lr1'] = learning_schedual.update_lr(
                        epoch, step)

                    if step == int(train_steps / 2) or step + 1 == train_steps:
                        self.model.eval()
                        torch.save(self.model.state_dict(),
                                   ROOT_SAVED_MODEL + 'temp_model.ckpt')
                        aux.new_line()
                        # dev ==========================================================================================
                        dev_result = self.eval()
                        logs['dev_loss'] = dev_result['loss']
                        logs['dev_sbj_precise'] = dev_result['sbj_precise']
                        logs['dev_sbj_recall'] = dev_result['sbj_recall']
                        logs['dev_sbj_f1'] = dev_result['sbj_f1']
                        logs['dev_spo_precise'] = dev_result['spo_precise']
                        logs['dev_spo_recall'] = dev_result['spo_recall']
                        logs['dev_spo_f1'] = dev_result['spo_f1']
                        logs['dev_precise'] = dev_result['precise']
                        logs['dev_recall'] = dev_result['recall']
                        logs['dev_f1'] = dev_result['f1']

                        # other thing
                        early_stopping(logs['dev_f1'])
                        if logs['dev_f1'] > 0.730:
                            optimizer.update_swa()

                        # test =========================================================================================
                        if (epoch + 1 == self.config.epochs and step + 1
                                == train_steps) or early_stopping.early_stop:
                            ending_flag = True
                            optimizer.swap_swa_sgd()
                            optimizer.bn_update(train_dataloader, self.model)
                            torch.save(self.model.state_dict(),
                                       ROOT_SAVED_MODEL + 'swa.ckpt')
                            self.test(ROOT_SAVED_MODEL + 'swa.ckpt')

                        self.model.train()
                aux.show_log(epoch, step, logs)
                if ending_flag:
                    return
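The training loop above wraps its Adam optimizer in SWA (the update_swa/swap_swa_sgd/bn_update calls match torchcontrib.optim.SWA, which is assumed here) and drives the averaging manually: weights are folded into the running average once the dev F1 clears a threshold, and the averaged weights are swapped in and BatchNorm statistics refreshed once at the end, before saving. A stripped-down, self-contained sketch of that lifecycle with a dummy model and data (the epoch gate stands in for the dev-F1 check):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torchcontrib.optim import SWA

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
data = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
loader = DataLoader(data, batch_size=32, shuffle=True)

optimizer = SWA(torch.optim.Adam(model.parameters(), lr=1e-3))  # manual SWA control

for epoch in range(5):
    for x, y in loader:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
    if epoch >= 2:                      # stand-in for "dev F1 > 0.730" in the loop above
        optimizer.update_swa()          # fold current weights into the running average

optimizer.swap_swa_sgd()                # load the averaged weights into the model
optimizer.bn_update(loader, model)      # refresh BatchNorm running statistics
torch.save(model.state_dict(), 'swa_sketch.ckpt')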
Exemplo n.º 22
0
    def fit_unsupervised(self,
                         train_loader,
                         valid_loader,
                         epochs,
                         lr,
                         eval_freq=20,
                         print_freq=1000,
                         anneal=None,
                         fname=None,
                         opt_type='ADAM',
                         imp_sampling=False,
                         semi_amortized_steps=0,
                         nsamples=1):
        if opt_type == 'ADAM':
            opt = torch.optim.Adam(self.parameters(), lr=lr)
            if semi_amortized_steps > 0:
                try:
                    opt = torch.optim.Adam(self.parameters(), lr=lr)
                    inf_opt = torch.optim.Adam(self.inf_network.parameters(),
                                               lr=lr)
                except AttributeError:
                    raise ValueError(
                        'Cannot optimize variational parameters for the selected model'
                    )

        elif opt_type == 'SWA':
            base_opt = torch.optim.Adam(self.parameters(), lr=lr)
            opt = SWA(base_opt, swa_start=100, swa_freq=50, swa_lr=lr)
            if semi_amortized_steps > 0:
                raise ValueError('not implemented for SWA')
        else:
            raise ValueError('bad opt type...')

        best_nelbo, best_nll, best_kl, best_ep = 100000, 100000, 100000, -1
        best_nll_estimate = 100000  # keeps the logging below defined before the first improvement
        best_params = {}
        if fname is not None:
            logging.basicConfig(
                filename=fname[:-4] + '_loss.log',
                filemode='w',
                format='%(asctime)s - %(levelname)s \t %(message)s',
                level=logging.INFO)
        if anneal is None:
            anneal = epochs / 10
        for epoch in range(1, epochs + 1):
            # linear KL warm-up: ramps to 1 over the first half of training
            # (this overwrites the default `anneal` computed above)
            anneal = min(1, epoch / (epochs * 0.5))
            self.train()
            batch_loss = 0
            idx = 0
            for data_tuples in train_loader:
                # Updates to inference network only
                if semi_amortized_steps > 0:
                    losslist = []
                    for k in range(semi_amortized_steps):
                        inf_opt.zero_grad()
                        _, loss = self.forward_unsupervised(*data_tuples,
                                                            anneal=1.)
                        loss.backward()
                        inf_opt.step()
                        losslist.append(loss.item())
                    if epoch % print_freq == 0:
                        print('Tightening variational params: ',
                              np.array(losslist))
                    anneal = 1.
                opt.zero_grad()
                if nsamples > 1:
                    dt = [
                        k.repeat(nsamples, 1) if k.dim() == 2 else k.repeat(
                            nsamples, 1, 1) for k in data_tuples
                    ]
                else:
                    dt = data_tuples
                _, loss = self.forward_unsupervised(*dt, anneal=anneal)
                loss.backward()
                """
                for n,p in self.named_parameters():
                    if np.any(np.isnan(p.grad.cpu().numpy())) or np.any(np.isinf(p.grad.cpu().numpy())):
                        print (n,'is nan or inf')
                        import ipdb;ipdb.set_trace()
                        print ('stop')
                """
                opt.step()
                idx += 1
                batch_loss += loss.item()
            if epoch % eval_freq == 0:
                self.eval()
                (nelbo, nll, kl, _), _ = self.forward_unsupervised(
                    *valid_loader.dataset.tensors, anneal=1.)
                if imp_sampling:
                    batch_nll = []
                    for i, valid_batch_loader in enumerate(valid_loader):
                        nll_estimate = self.imp_sampling(*valid_batch_loader,
                                                         nelbo,
                                                         anneal=1.)
                        nll_estimate = nll_estimate.item()
                        batch_nll.append(nll_estimate)
                    nll_estimate = np.mean(batch_nll)
                nelbo, nll, kl = nelbo.item(), nll.item(), kl.item()
                if nelbo < best_nelbo:
                    best_nelbo = nelbo
                    best_nll = nll
                    best_kl = kl
                    best_ep = epoch
                    if imp_sampling:
                        best_nll_estimate = nll_estimate
                    self.collect_best_params(best_params)
                    if fname is not None:
                        if opt_type == 'SWA':
                            # swap the SWA-averaged weights into the model before checkpointing
                            # (note: the weights are not swapped back afterwards)
                            opt.swap_swa_sgd()
                        torch.save(self.state_dict(), fname)
                self.train()
            if epoch % print_freq == 0:
                if imp_sampling:
                    print('Ep', epoch, ' Loss:', batch_loss / float(idx), ', Anneal:', anneal,
                          ', Best NELBO:%.3f, NLL est.:%.3f, NLL:%.3f, KL: %.3f @ epoch %d' %
                          (best_nelbo, best_nll_estimate, best_nll, best_kl, best_ep))
                else:
                    print('Ep', epoch, ' Loss:', batch_loss / float(idx), ', Anneal:', anneal,
                          ', Best NELBO:%.3f, NLL:%.3f, KL: %.3f @ epoch %d' %
                          (best_nelbo, best_nll, best_kl, best_ep))
                if fname is not None:
                    msg = 'Ep: %d, Loss: %f, Anneal: %.3f, Best NELBO:%.3f NLL:%.3f, KL: %.3f @ epoch %d'
                    logging.info(msg, epoch, batch_loss / float(idx), anneal,
                                 best_nelbo, best_nll, best_kl, best_ep)
        print('Best NELBO:%.3f, NLL:%.3f, KL:%.3f@ epoch %d' %
              (best_nelbo, best_nll, best_kl, best_ep))
        self.best_params = best_params
        self.best_nelbo = best_nelbo
        self.best_nll = best_nll
        self.best_kl = best_kl
        self.best_ep = best_ep
        return best_params, best_nelbo, best_nll, best_kl, best_ep
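The schedule `anneal = min(1, epoch / (epochs * 0.5))` used inside fit_unsupervised linearly warms the KL weight from near zero to 1 over the first half of training and then holds it at 1. A tiny standalone sketch of that ramp (illustrative only, not part of the original code):

# KL warm-up weight as used in fit_unsupervised: linear ramp, saturating at 1.0
# halfway through training.
def kl_anneal_weight(epoch, epochs):
    return min(1.0, epoch / (epochs * 0.5))

if __name__ == "__main__":
    epochs = 10
    for epoch in range(1, epochs + 1):
        print(epoch, kl_anneal_weight(epoch, epochs))
    # epochs 1..5 give 0.2, 0.4, 0.6, 0.8, 1.0; epochs 6..10 stay at 1.0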