Example #1
class Trainer():
    def __init__(self, config_path):
        self.image_config, self.model_config, self.run_config = LoadConfig(
            config_path=config_path).train_config()
        self.device = torch.device(
            'cuda:%d' % self.run_config['device_ids'][0]
            if torch.cuda.is_available() else 'cpu')
        self.model = getModel(self.model_config)
        os.makedirs(self.run_config['model_save_path'], exist_ok=True)
        self.run_config['num_workers'] = self.run_config['num_workers'] * len(
            self.run_config['device_ids'])
        self.train_set = Data(root=self.image_config['image_path'],
                              phase='train',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.valid_set = Data(root=self.image_config['image_path'],
                              phase='valid',
                              data_name=self.image_config['data_name'],
                              img_mode=self.image_config['image_mode'],
                              n_classes=self.model_config['num_classes'],
                              size=self.image_config['image_size'],
                              scale=self.image_config['image_scale'])
        self.className = self.valid_set.className
        self.train_loader = DataLoader(
            self.train_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        self.valid_loader = DataLoader(
            self.valid_set,
            batch_size=self.run_config['batch_size'],
            shuffle=True,
            num_workers=self.run_config['num_workers'],
            pin_memory=True,
            drop_last=False)
        train_params = self.model.parameters()
        self.optimizer = RAdam(train_params,
                               lr=eval(self.run_config['lr']),
                               weight_decay=eval(
                                   self.run_config['weight_decay']))
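        # Optionally wrap the base optimizer in torchcontrib's SWA so a running
        # average of the weights is maintained during training.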
        if self.run_config['swa']:
            self.optimizer = SWA(self.optimizer,
                                 swa_start=10,
                                 swa_freq=5,
                                 swa_lr=0.005)
        # Set up the learning-rate scheduling strategy
        self.lr_scheduler = utils.adjustLR.AdjustLr(self.optimizer)
        if self.run_config['use_weight_balance']:
            weight = utils.weight_balance.getWeight(
                self.run_config['weights_file'])
        else:
            weight = None
        self.Criterion = SegmentationLosses(weight=weight,
                                            cuda=True,
                                            device=self.device,
                                            batch_average=False)
        self.metric = utils.metrics.MetricMeter(
            self.model_config['num_classes'])

    @logger.catch  # record any exception in the log
    def __call__(self):
        # Set up logging
        self.global_name = self.model_config['model_name']
        logger.add(os.path.join(
            self.image_config['image_path'], 'log',
            'log_' + self.global_name + '/train_{time}.log'),
                   format="{time} {level} {message}",
                   level="INFO",
                   encoding='utf-8')
        self.writer = SummaryWriter(logdir=os.path.join(
            self.image_config['image_path'], 'run', 'runs_' +
            self.global_name))
        logger.info("image_config: {} \n model_config: {} \n run_config: {}",
                    self.image_config, self.model_config, self.run_config)
        # If more than one GPU is available, use data parallelism
        if len(self.run_config['device_ids']) > 1:
            self.model = nn.DataParallel(
                self.model, device_ids=self.run_config['device_ids'])
        self.model.to(device=self.device)
        cnt = 0
        # Load a pretrained model if one is provided
        if self.run_config['pretrain'] != '':
            logger.info("loading pretrain %s" % self.run_config['pretrain'])
            try:
                self.load_checkpoint(use_optimizer=True,
                                     use_epoch=True,
                                     use_miou=True)
            except Exception:
                print('loading model with changed keys')
                self.load_checkpoint_with_changed(use_optimizer=False,
                                                  use_epoch=False,
                                                  use_miou=False)
        logger.info("start training")

        for epoch in range(self.run_config['start_epoch'],
                           self.run_config['epoch']):
            lr = self.optimizer.param_groups[0]['lr']
            print('epoch=%d, lr=%.8f' % (epoch, lr))
            self.train_epoch(epoch, lr)
            valid_miou = self.valid_epoch(epoch)
            # Choose which learning-rate scheduling strategy to apply
            self.lr_scheduler.LambdaLR_(milestone=5,
                                        gamma=0.92).step(epoch=epoch)
            self.save_checkpoint(epoch, valid_miou, 'last_' + self.global_name)
            if valid_miou > self.run_config['best_miou']:
                cnt = 0
                self.save_checkpoint(epoch, valid_miou,
                                     'best_' + self.global_name)
                logger.info("#############   %d saved   ##############" %
                            epoch)
                self.run_config['best_miou'] = valid_miou
            else:
                cnt += 1
                if cnt == self.run_config['early_stop']:
                    logger.info("early stop")
                    break
        self.writer.close()

    def train_epoch(self, epoch, lr):
        self.metric.reset()
        train_loss = 0.0
        train_miou = 0.0
        tbar = tqdm(self.train_loader)
        self.model.train()
        for i, (image, mask, edge) in enumerate(tbar):
            tbar.set_description('train_miou:%.6f' % train_miou)
            tbar.set_postfix({"train_loss": train_loss})
            image = image.to(self.device)
            mask = mask.to(self.device)
            edge = edge.to(self.device)
            self.optimizer.zero_grad()
            out = self.model(image)
            if isinstance(out, tuple):
                aux_out, final_out = out[0], out[1]
            else:
                aux_out, final_out = None, out
            if self.model_config['model_name'] == 'ocrnet':
                aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out, mask)
                cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                loss = 0.4 * aux_loss + cls_loss
                loss = loss.mean()
            elif self.model_config['model_name'] == 'hrnet_duc':
                loss_body = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
                loss_edge = self.Criterion.build_loss(mode='dice')(
                    aux_out.squeeze(), edge)
                loss = loss_body + loss_edge
                loss = loss.mean()
            else:
                loss = self.Criterion.build_loss(
                    mode=self.run_config['loss_type'])(final_out, mask)
            loss.backward()
            self.optimizer.step()
            if self.run_config['swa']:
                self.optimizer.swap_swa_sgd()
            with torch.no_grad():
                train_loss = ((train_loss * i) + loss.item()) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                train_miou, train_ious = self.metric.miou()
                train_fwiou = self.metric.fw_iou()
                train_accu = self.metric.pixel_accuracy()
                train_fwaccu = self.metric.pixel_accuracy_class()
        logger.info(
            "Epoch:%2d\t lr:%.8f\t Train loss:%.4f\t Train FWiou:%.4f\t Train Miou:%.4f\t Train accu:%.4f\t "
            "Train fwaccu:%.4f" % (epoch, lr, train_loss, train_fwiou,
                                   train_miou, train_accu, train_fwaccu))
        cls = ""
        ious = list()
        ious_dict = OrderedDict()
        for i, c in enumerate(self.className):
            ious_dict[c] = train_ious[i]
            ious.append(ious_dict[c])
            cls += "%s:" % c + "%.4f "
        ious = tuple(ious)
        logger.info(cls % ious)
        # tensorboard
        self.writer.add_scalar("lr", lr, epoch)
        self.writer.add_scalar("loss/train_loss", train_loss, epoch)
        self.writer.add_scalar("miou/train_miou", train_miou, epoch)
        self.writer.add_scalar("fwiou/train_fwiou", train_fwiou, epoch)
        self.writer.add_scalar("accuracy/train_accu", train_accu, epoch)
        self.writer.add_scalar("fwaccuracy/train_fwaccu", train_fwaccu, epoch)
        self.writer.add_scalars("ious/train_ious", ious_dict, epoch)

    def valid_epoch(self, epoch):
        self.metric.reset()
        valid_loss = 0.0
        valid_miou = 0.0
        tbar = tqdm(self.valid_loader)
        self.model.eval()
        with torch.no_grad():
            for i, (image, mask, edge) in enumerate(tbar):
                tbar.set_description('valid_miou:%.6f' % valid_miou)
                tbar.set_postfix({"valid_loss": valid_loss})
                image = image.to(self.device)
                mask = mask.to(self.device)
                edge = edge.to(self.device)
                out = self.model(image)
                if isinstance(out, tuple):
                    aux_out, final_out = out[0], out[1]
                else:
                    aux_out, final_out = None, out
                if self.model_config['model_name'] == 'ocrnet':
                    aux_loss = self.Criterion.build_loss(mode='rmi')(aux_out,
                                                                     mask)
                    cls_loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                    mask)
                    loss = 0.4 * aux_loss + cls_loss
                    loss = loss.mean()
                elif self.model_config['model_name'] == 'hrnet_duc':
                    loss_body = self.Criterion.build_loss(
                        mode=self.run_config['loss_type'])(final_out, mask)
                    loss_edge = self.Criterion.build_loss(mode='dice')(
                        aux_out.squeeze(), edge)
                    loss = loss_body + loss_edge
                    # loss = loss.mean()
                else:
                    loss = self.Criterion.build_loss(mode='ce')(final_out,
                                                                mask)
                valid_loss = ((valid_loss * i) + float(loss)) / (i + 1)
                _, pred = torch.max(final_out, dim=1)
                self.metric.add(pred.cpu().numpy(), mask.cpu().numpy())
                valid_miou, valid_ious = self.metric.miou()
                valid_fwiou = self.metric.fw_iou()
                valid_accu = self.metric.pixel_accuracy()
                valid_fwaccu = self.metric.pixel_accuracy_class()
            logger.info(
                "epoch:%d\t valid loss:%.4f\t valid fwiou:%.4f\t valid miou:%.4f valid accu:%.4f\t "
                "valid fwaccu:%.4f\t" % (epoch, valid_loss, valid_fwiou,
                                         valid_miou, valid_accu, valid_fwaccu))
            ious = list()
            cls = ""
            ious_dict = OrderedDict()
            for i, c in enumerate(self.className):
                ious_dict[c] = valid_ious[i]
                ious.append(ious_dict[c])
                cls += "%s:" % c + "%.4f "
            ious = tuple(ious)
            logger.info(cls % ious)
            self.writer.add_scalar("loss/valid_loss", valid_loss, epoch)
            self.writer.add_scalar("miou/valid_miou", valid_miou, epoch)
            self.writer.add_scalar("fwiou/valid_fwiou", valid_fwiou, epoch)
            self.writer.add_scalar("accuracy/valid_accu", valid_accu, epoch)
            self.writer.add_scalar("fwaccuracy/valid_fwaccu", valid_fwaccu,
                                   epoch)
            self.writer.add_scalars("ious/valid_ious", ious_dict, epoch)
        return valid_miou

    def save_checkpoint(self, epoch, best_miou, flag):
        meta = {
            'epoch': epoch,
            'model': self.model.state_dict(),
            'optim': self.optimizer.state_dict(),
            'bmiou': best_miou
        }
        try:
            torch.save(meta,
                       os.path.join(self.run_config['model_save_path'],
                                    '%s.pth' % flag),
                       _use_new_zipfile_serialization=False)
        except TypeError:  # older torch versions lack _use_new_zipfile_serialization
            torch.save(
                meta,
                os.path.join(self.run_config['model_save_path'],
                             '%s.pth' % flag))

    def load_checkpoint(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        self.model.load_state_dict(state_dict['model'])
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']

    def load_checkpoint_with_changed(self, use_optimizer, use_epoch, use_miou):
        state_dict = torch.load(self.run_config['pretrain'],
                                map_location=self.device)
        pretrain_dict = state_dict['model']
        model_dict = self.model.state_dict()
        pretrain_dict = {
            k: v
            for k, v in pretrain_dict.items()
            if k in model_dict and 'edge' not in k
        }
        model_dict.update(pretrain_dict)
        self.model.load_state_dict(model_dict)
        if use_optimizer:
            self.optimizer.load_state_dict(state_dict['optim'])
        if use_epoch:
            self.run_config['start_epoch'] = state_dict['epoch'] + 1
        if use_miou:
            self.run_config['best_miou'] = state_dict['bmiou']
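
For reference, the torchcontrib SWA workflow used throughout these examples boils down to a short loop. The sketch below is a minimal, hypothetical version: model, loader and loss_fn are placeholders, and the hyperparameters are the ones from Example #1.

import torch
from torchcontrib.optim import SWA

base_opt = torch.optim.SGD(model.parameters(), lr=0.01)
# start averaging after step 10, take a snapshot of the weights every 5 steps
opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.005)

for image, mask in loader:
    opt.zero_grad()
    loss = loss_fn(model(image), mask)
    loss.backward()
    opt.step()

# after training: swap in the averaged weights and refresh BatchNorm statistics
opt.swap_swa_sgd()
opt.bn_update(loader, model)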
Example #2
        checkpoint = torch.load(model_path)
        net.load_state_dict(checkpoint['model_state_dict'])
    net = net.to(device)

    # optimization
    # TODO: Choose an optimizer
    # optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=args.learning_rate)
    scheduler = None
    if args.use_swa:
        steps_per_epoch = len(train_dataloader) // args.batch_size
        optimizer = SWA(optimizer, swa_start=20 * steps_per_epoch, swa_freq=steps_per_epoch)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer.optimizer, mode="max", patience=5, factor=0.5)

    if args.resume_dir:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        # best_pesq = checkpoint['pesq']
        best_loss = checkpoint['loss']
    else:
        start_epoch = 0
        best_loss = 1e8
        # best_pesq = 0.0
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)

    # add graph to tensorboard
    if args.add_graph:
        # TODO: Create a dummy input for your model
        # dummy = torch.randn(16, 1, args.hop_length * 16).to(device)
        writer.add_graph(net, dummy)
Example #3
            optimizer = optimizer_dict[optimizer_name](model.dmg_model.parameters(), lr=lr)
        else:
            optimizer = optimizer_dict[optimizer_name](model.parameters(), lr=lr)

    # Call

    print("Starting model training....")

    n_epochs = setting_dict['epochs']
    lr_patience = setting_dict['optimizer']['sheduler']['patience']
    lr_factor = setting_dict['optimizer']['sheduler']['factor']

    if weight_path is None:
        best_epoch = train(model, dataloaders, objective, optimizer, n_epochs,
                           Path_list[1], Path_list[2],
                           lr_patience=lr_patience, lr_factor=lr_factor,
                           dice=False, seperate_loss=False,
                           adabn=setting_dict["data"]["adabn_train"],
                           own_sheduler=(not setting_dict["optimizer"]["longshedule"]))
    else:
        checkpoint = torch.load(weight_path)
        optimizer.load_state_dict(checkpoint["optimizer"])
        best_epoch = train(model, dataloaders, objective, optimizer,
                           n_epochs - checkpoint["epoch"],
                           Path_list[1], Path_list[2],
                           start_epoch=checkpoint["epoch"] + 1,
                           loss_dict=checkpoint["loss_dict"],
                           lr_patience=lr_patience, lr_factor=lr_factor,
                           dice=False, seperate_loss=False,
                           adabn=setting_dict["data"]["adabn_train"],
                           own_sheduler=(not setting_dict["optimizer"]["longshedule"]))

    print("model training finished! yey!")

    if optimizer_name == "SWA":
        print ("Updating batch norm pars for SWA")
        train_dataset.dataset.SWA = True
        SWA_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=cpu_count)
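        # Swap the averaged SWA weights into the model and recompute the
        # BatchNorm running statistics with a pass over the training data.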
        optimizer.swap_swa_sgd()
        optimizer.bn_update(SWA_loader, model, device='cuda')
        state = {
                'epoch': n_epochs,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss_dict': {}
Example #4
class Optimizer:
    optimizer_cls = None
    optimizer = None
    parameters = None

    def __init__(self,
                 gradient_clipping,
                 swa_start=None,
                 swa_freq=None,
                 swa_lr=None,
                 **kwargs):
        self.gradient_clipping = gradient_clipping
        self.optimizer_kwargs = kwargs
        self.swa_start = swa_start
        self.swa_freq = swa_freq
        self.swa_lr = swa_lr

    def set_parameters(self, parameters):
        self.parameters = tuple(parameters)
        self.optimizer = self.optimizer_cls(self.parameters,
                                            **self.optimizer_kwargs)
        if self.swa_start is not None:
            from torchcontrib.optim import SWA
            assert self.swa_freq is not None, self.swa_freq
            assert self.swa_lr is not None, self.swa_lr
            self.optimizer = SWA(self.optimizer,
                                 swa_start=self.swa_start,
                                 swa_freq=self.swa_freq,
                                 swa_lr=self.swa_lr)

    def check_if_set(self):
        assert self.optimizer is not None, \
            'The optimizer is not initialized, call set_parameters before' \
            ' using any of the optimizer functions'

    def zero_grad(self):
        self.check_if_set()
        return self.optimizer.zero_grad()

    def step(self):
        self.check_if_set()
        return self.optimizer.step()

    def swap_swa_sgd(self):
        self.check_if_set()
        from torchcontrib.optim import SWA
        assert isinstance(self.optimizer, SWA), self.optimizer
        return self.optimizer.swap_swa_sgd()

    def clip_grad(self):
        self.check_if_set()
        # Todo: report clipped and unclipped
        # Todo: allow clip=None but still report grad_norm
        grad_clips = self.gradient_clipping
        return torch.nn.utils.clip_grad_norm_(self.parameters, grad_clips)

    def to(self, device):
        if device is None:
            return
        self.check_if_set()
        for state in self.optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)

    def cpu(self):
        return self.to('cpu')

    def cuda(self, device=None):
        assert device is None or isinstance(device, int), device
        if device is None:
            device = torch.device('cuda')
        return self.to(device)

    def load_state_dict(self, state_dict):
        self.check_if_set()
        return self.optimizer.load_state_dict(state_dict)

    def state_dict(self):
        self.check_if_set()
        return self.optimizer.state_dict()
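
optimizer_cls is intentionally left as None above and is expected to be supplied by a subclass. A minimal, hypothetical sketch of how this wrapper might be used (the Adam subclass, model, batch, target and criterion below are placeholders, not part of the original code):

import torch

class Adam(Optimizer):
    optimizer_cls = torch.optim.Adam

opt = Adam(gradient_clipping=5.0, swa_start=1000, swa_freq=100, swa_lr=1e-3,
           lr=3e-4)
opt.set_parameters(model.parameters())   # builds the Adam instance and wraps it in SWA

opt.zero_grad()
loss = criterion(model(batch), target)
loss.backward()
opt.clip_grad()                          # clip_grad_norm_ with gradient_clipping
opt.step()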
Example #5
class Trainer(object):
    def __init__(self,
                 args,
                 train_dataloader=None,
                 validate_dataloader=None,
                 test_dataloader=None):
        self.args = args
        self.train_dataloader = train_dataloader
        self.validate_dataloader = validate_dataloader
        self.test_dataloader = test_dataloader

        self.label_lst = [i for i in range(self.args.num_classes)]
        self.num_labels = self.args.num_classes

        self.config_class = AutoConfig
        self.model_class = BertForSequenceClassification

        self.config = self.config_class.from_pretrained(
            self.args.bert_model_name,
            num_labels=self.num_labels,
            finetuning_task='nsmc',
            id2label={str(i): label
                      for i, label in enumerate(self.label_lst)},
            label2id={label: i
                      for i, label in enumerate(self.label_lst)})
        self.model = self.model_class.from_pretrained(
            self.args.bert_model_name, config=self.config)
        self.optimizer = None
        self.scheduler = None

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available(
        ) and args.cuda else "cpu"
        self.model.to(self.device)

    def train(self, alpha, gamma):
        train_dataloader = self.train_dataloader

        t_total = len(train_dataloader) * self.args.num_epochs

        # Prepare optimizer and schedule (linear warmup and decay)
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]

        if self.args.use_swa:
            base_opt = AdamW(optimizer_grouped_parameters,
                             lr=self.args.lr,
                             eps=1e-8)
            self.optimizer = SWA(base_opt,
                                 swa_start=4 * len(train_dataloader),
                                 swa_freq=100,
                                 swa_lr=5e-5)
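            # Mirror the wrapped optimizer's param_groups, state and defaults on
            # the SWA object so the warmup scheduler created below can step it.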
            self.optimizer.param_groups = self.optimizer.optimizer.param_groups
            self.optimizer.state = self.optimizer.optimizer.state
            self.optimizer.defaults = self.optimizer.optimizer.defaults

        else:
            self.optimizer = optimizer = AdamW(optimizer_grouped_parameters,
                                               lr=self.args.lr,
                                               eps=1e-8)
        self.scheduler = scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=100,
            num_training_steps=self.args.num_epochs * len(train_dataloader))
        self.criterion = FocalLoss(alpha=alpha, gamma=gamma)

        # Train!
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d",
                    len(self.train_dataloader) * self.args.batch_size)
        logger.info("  Num Epochs = %d", self.args.num_epochs)
        logger.info("  Total train batch size = %d", self.args.batch_size)
        logger.info("  Total optimization steps = %d", t_total)

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()
        self.optimizer.zero_grad()

        train_iterator = trange(int(self.args.num_epochs), desc="Epoch")

        fin_result = None
        f1_max = 0.0
        self.model.train()

        for epoch in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            for step, batch in enumerate(epoch_iterator):

                batch = tuple(t.to(self.device) for t in batch)  # GPU or CPU
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }

                # outputs = self.model(**inputs)
                # loss = outputs[0]

                # # Custom Loss
                loss, logits = self.model(**inputs)
                logits = torch.sigmoid(logits)

                labels = torch.zeros(
                    (len(batch[3]), self.num_labels)).to(self.device)
                labels[range(len(batch[3])), batch[3]] = 1

                loss = self.criterion(logits, labels)

                loss.backward()
                self.optimizer.step()
                self.scheduler.step()  # Update learning rate schedule

                self.model.zero_grad()
                self.optimizer.zero_grad()

                tr_loss += loss.item()
                global_step += 1
                logger.info('train loss %f', loss.item())

            logger.info('total train loss %f', tr_loss / global_step)
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            fin_result = self.evaluate("validate")
            self.save_model(epoch)
            self.model.train()
            if epoch >= 4 and self.args.use_swa:
                self.optimizer.swap_swa_sgd()

            f1_max = max(fin_result['f1_macro'], f1_max)

        if epoch >= 4 and self.args.use_swa:
            self.optimizer.swap_swa_sgd()
        with open(os.path.join(self.args.base_dir, self.args.result_dir,
                               self.args.train_id, 'param_seach.txt'),
                  "a",
                  encoding="utf-8") as f:
            f.write('alpha: {}, gamma: {}, f1_macro: {}\n'.format(
                alpha, gamma, f1_max))
        return f1_max

    def evaluate(self, mode='test'):
        if mode == 'test':
            dataloader = self.test_dataloader
        elif mode == 'validate':
            dataloader = self.validate_dataloader
        else:
            raise Exception("Only dev and test dataset available")

        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("  Num examples = %d",
                    len(dataloader) * self.args.batch_size)
        logger.info("  Batch size = %d", self.args.batch_size)
        eval_loss = 0.0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None

        self.model.eval()

        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'input_ids': batch[0],
                    'attention_mask': batch[1],
                    'labels': batch[3],
                    'token_type_ids': batch[2]
                }
                outputs = self.model(**inputs)
                tmp_eval_loss, logits = outputs[:2]

                eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1

            if preds is None:
                preds = logits.detach().cpu().numpy()
                out_label_ids = inputs['labels'].detach().cpu().numpy()
            else:
                preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
                out_label_ids = np.append(
                    out_label_ids,
                    inputs['labels'].detach().cpu().numpy(),
                    axis=0)

        eval_loss = eval_loss / nb_eval_steps
        results = {"loss": eval_loss}

        preds = np.argmax(preds, axis=1)
        result = compute_metrics(preds, out_label_ids)
        results.update(result)

        p_macro, r_macro, f_macro, support_macro \
            = precision_recall_fscore_support(y_true=out_label_ids, y_pred=preds,
                                              labels=[i for i in range(self.num_labels)], average='macro')

        results.update({
            'precision': p_macro,
            'recall': r_macro,
            'f1_macro': f_macro
        })

        with open(self.args.prediction_file, "w", encoding="utf-8") as f:
            for pred in preds:
                f.write("{}\n".format(pred))

        if mode == 'validate':
            logger.info("***** Eval results *****")
            for key in sorted(results.keys()):
                logger.info("  %s = %s", key, str(results[key]))

        return results

    def save_model(self, num=0):
        state = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }
        torch.save(
            state,
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, 'epoch_' + str(num) + '.pth'))
        logger.info('model saved')

    def load_model(self, model_name):

        state = torch.load(
            os.path.join(self.args.base_dir, self.args.result_dir,
                         self.args.train_id, model_name))
        self.model.load_state_dict(state['model'])
        if self.optimizer is not None:
            self.optimizer.load_state_dict(state['optimizer'])
        if self.scheduler is not None:
            self.scheduler.load_state_dict(state['scheduler'])
        logger.info('model loaded')
Example #6
def main():

    maxIOU = 0.0
    assert torch.cuda.is_available()
    torch.backends.cudnn.benchmark = True
    model_fname = '../data/model_swa_8/deeplabv3_{0}_epoch%d.pth'.format(
        'crops')
    focal_loss = FocalLoss2d()
    train_dataset = CropSegmentation(train=True, crop_size=args.crop_size)
    #     test_dataset = CropSegmentation(train=False, crop_size=args.crop_size)

    model = torchvision.models.segmentation.deeplabv3_resnet50(
        pretrained=False, progress=True, num_classes=5, aux_loss=True)

    if args.train:
        weight = np.ones(4)
        weight[2] = 5
        weight[3] = 5
        w = torch.FloatTensor(weight).cuda()
        criterion = nn.CrossEntropyLoss()  #ignore_index=255 weight=w
        model = nn.DataParallel(model).cuda()

        for param in model.parameters():
            param.requires_grad = True

        optimizer1 = optim.SGD(model.parameters(),
                               lr=config.lr,
                               momentum=0.9,
                               weight_decay=1e-4)
        # Wrap the base SGD optimizer in SWA (manual averaging mode) before
        # attaching the cosine-annealing LR scheduler to it.
        optimizer = SWA(optimizer1)
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=(args.epochs // 9) + 1)

        dataset_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=args.train,
            pin_memory=True,
            num_workers=args.workers)

        max_iter = args.epochs * len(dataset_loader)
        losses = AverageMeter()
        start_epoch = 0

        if args.resume:
            if os.path.isfile(args.resume):
                print('=> loading checkpoint {0}'.format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print('=> loaded checkpoint {0} (epoch {1})'.format(
                    args.resume, checkpoint['epoch']))

            else:
                print('=> no checkpoint found at {0}'.format(args.resume))

        for epoch in range(start_epoch, args.epochs):
            scheduler.step(epoch)
            model.train()
            for i, (inputs, target) in enumerate(dataset_loader):

                inputs = Variable(inputs.cuda())
                target = Variable(target.cuda())
                outputs = model(inputs)
                loss1 = focal_loss(outputs['out'], target)
                loss2 = focal_loss(outputs['aux'], target)
                loss01 = loss1 + 0.1 * loss2
                loss3 = lovasz_softmax(outputs['out'], target)
                loss4 = lovasz_softmax(outputs['aux'], target)
                loss02 = loss3 + 0.1 * loss4
                loss = loss01 + loss02
                if np.isnan(loss.item()) or np.isinf(loss.item()):
                    pdb.set_trace()

                losses.update(loss.item(), args.batch_size)

                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
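                # Manual SWA: fold the current weights into the running average
                # every 5 iterations once the first 10 iterations have passed.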
                if i > 10 and i % 5 == 0:
                    optimizer.update_swa()

                print('epoch: {0}\t'
                      'iter: {1}/{2}\t'
                      'loss: {loss.val:.4f} ({loss.ema:.4f})'.format(
                          epoch + 1, i + 1, len(dataset_loader), loss=losses))

            if epoch > 5:
                torch.save(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                    }, model_fname % (epoch + 1))
        optimizer.swap_swa_sgd()
        torch.save(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_fname % (665 + 1))
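
Example #6 swaps the averaged weights in before the final save; when BatchNorm layers are present, torchcontrib's SWA also provides bn_update to recompute their running statistics for those weights, as Example #3 shows, e.g.:

optimizer.swap_swa_sgd()
optimizer.bn_update(dataset_loader, model, device='cuda')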