Example #1
def train(model, dataloader, scaler, optimizer, scheduler, device):
    pbar = ProgressBar(n_total=len(dataloader), desc='Training')
    train_loss = AverageMeter()
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        # move the batch tensors from the loader onto the device
        input_ids, attn_mask, start_pos, end_pos, token_type_ids = (
            batch['input_ids'].to(device), batch['attention_mask'].to(device),
            batch['start_positions'].to(device),
            batch['end_positions'].to(device),
            batch['token_type_ids'].to(device))
        # clear gradients
        optimizer.zero_grad()
        # use mixed precision
        with autocast():
            # forward
            out = model(input_ids=input_ids,
                        attention_mask=attn_mask,
                        start_positions=start_pos,
                        end_positions=end_pos,
                        token_type_ids=token_type_ids)
        # backward
        scaler.scale(out[0]).backward()  # out[0] = loss
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        train_loss.update(out[0].item(), n=1)
        pbar(step=batch_idx, info={'loss': train_loss.avg})
    train_log = {'train_loss': train_loss.avg}
    return train_log
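
A note on the AverageMeter helper: every example on this page relies on it, but its definition is not shown. A minimal sketch consistent with the calls used here (update(val, n) plus the val, sum, and avg attributes) could look like this; it is an illustrative assumption, not the exact class from these projects.

class AverageMeter:
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # val is the latest measurement; n is how many samples it covers
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count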
Example #2
    def train_valid(self, data, round_num):
        y, train_mask, valid_mask, test_mask, label_weights = (
            data.y, data.train_mask, data.valid_mask, data.test_mask,
            data.label_weights)
        patience = self.max_patience
        best_valid_score = 0
        valid_acc_meter = AverageMeter()
        for epoch in range(self.max_epochs):

            # train
            self.train()
            self.optimizer.zero_grad()
            preds = self.forward(data)
            loss = F.cross_entropy(preds[train_mask], y[train_mask], weight=label_weights)
            loss.backward()
            self.optimizer.step()

            # valid
            self.eval()
            with torch.no_grad():
                preds = F.softmax(self.forward(data), dim=-1)
                valid_preds, test_preds = preds[valid_mask], preds[test_mask]
                valid_score = f1_score(y[valid_mask].cpu(), valid_preds.max(1)[1].flatten().cpu(), average='micro')
            valid_acc_meter.update(valid_score)
            # patience
            if valid_acc_meter.avg > best_valid_score:
                best_valid_score = valid_acc_meter.avg
                self.current_round_best_preds = test_preds
                patience = self.max_patience
            else:
                patience -= 1

            if patience == 0:
                break

        return best_valid_score
Example #3
    def valid_epoch(self, data_loader):
        pbar = ProgressBar(n_total=len(data_loader), desc='Evaluating')
        self.entity_score.reset()
        valid_loss = AverageMeter()
        #output_file = jsonlines.open("data/case_out.jsonl","w")
        for step, batch in enumerate(data_loader):

            batch = tuple(t.to(self.device) for t in batch)

            input_ids, input_mask, trigger_mask, segment_ids, label_ids, input_lens, one_hot_labels = batch

            if not self.trigger:
                trigger_mask = None
            if not self.partial:
                one_hot_labels = None
            input_lens = input_lens.cpu().detach().numpy().tolist()
            self.model.eval()
            with torch.no_grad():
                features, loss = self.model(input_ids,
                                            segment_ids,
                                            input_mask,
                                            trigger_mask,
                                            label_ids,
                                            input_lens,
                                            one_hot_labels=one_hot_labels)
                tags, _ = self.model.crf._obtain_labels(
                    features, self.id2label, input_lens)
            valid_loss.update(val=loss.item(), n=input_ids.size(0))
            pbar(step=step, info={"loss": loss.item()})
            label_ids = label_ids.to('cpu').numpy().tolist()

            for i, label in enumerate(label_ids):
                temp_1 = []
                temp_2 = []
                for j, m in enumerate(label):
                    if j == 0:
                        continue
                    elif j == input_lens[i] - 1:  # stop at the end of the sequence
                        r = self.entity_score.update(pred_paths=[temp_2],
                                                     label_paths=[temp_1])
                        r["input_ids"] = input_ids[i, :].to(
                            "cpu").numpy().tolist()
                        #output_file.write(r)
                        break
                    else:
                        temp_1.append(self.id2label[label_ids[i][j]])
                        try:
                            temp_2.append(tags[i][j])
                        except Exception as e:
                            print(i, j, e)

        valid_info, ex_valid_info, class_info = self.entity_score.result()
        ex_info = dict(ex_valid_info)
        info = dict(valid_info)
        info['valid_loss'] = valid_loss.avg
        if 'cuda' in str(self.device):
            torch.cuda.empty_cache()
        return info, ex_info, class_info
Example #4
        def start(self, train_loader, train_set, valid_set=None, valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images / float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(
                valid_set.num_images / float(self.cf.valid_batch_size))

            # Define early stopping control
            if self.cf.early_stopping:
                early_stopping = EarlyStopping(self.cf)
            else:
                early_stopping = None

            # Train process
            for epoch in tqdm(range(self.curr_epoch, self.cf.epochs + 1), desc='Training', file=sys.stdout):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                self.logger_stats.write('\n\t ------ Epoch: ' + str(epoch) + ' ------ \n')

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(self.confm_list, self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(self.stats.train, epoch,
                                             os.path.join(self.cf.train_json_path,
                                                          'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_stopping, epoch)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch, self.cf.best_json_file)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()
Example #5
        def start(self,
                  valid_set,
                  valid_loader,
                  mode='Validation',
                  epoch=None,
                  global_bar=None,
                  save_folder=None):
            confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

            self.val_loss = AverageMeter()

            # Initialize epoch progress bar
            val_num_batches = math.ceil(valid_set.num_images /
                                        float(self.cf.valid_batch_size))
            prev_msg = '\n' + mode + ' estimated time...\n'
            bar = ProgressBar(val_num_batches, lenBar=20)
            bar.set_prev_msg(prev_msg)
            bar.update(show=False)

            # Validate model
            if self.cf.problem_type == 'detection':
                self.validation_loop(epoch, valid_loader, valid_set, bar,
                                     global_bar, save_folder)
            else:
                self.validation_loop(epoch, valid_loader, valid_set, bar,
                                     global_bar, confm_list)

            # Compute stats
            self.compute_stats(np.asarray(self.stats.val.conf_m),
                               self.val_loss)

            # Save stats
            self.save_stats(epoch)
            if mode == 'Epoch Validation':
                self.logger_stats.write_stat(
                    self.stats.train, epoch,
                    os.path.join(self.cf.train_json_path,
                                 'valid_epoch_' + str(epoch) + '.json'))
            elif mode == 'Validation':
                self.logger_stats.write_stat(self.stats.val, epoch,
                                             self.cf.val_json_file)
            elif mode == 'Test':
                self.logger_stats.write_stat(self.stats.val, epoch,
                                             self.cf.test_json_file)
Example #6
    def train_epoch(self, data_loader):

        pbar = ProgressBar(n_total=len(data_loader), desc='Training')
        tr_loss = AverageMeter()

        for step, batch in enumerate(data_loader):
            self.model.train()

            batch = tuple(t.to(self.device) for t in batch)

            input_ids, input_mask, trigger_mask, segment_ids, label_ids, input_lens, one_hot_labels = batch
            if not self.partial:
                one_hot_labels = None
            if not self.trigger:
                trigger_mask = None
            input_lens = input_lens.cpu().detach().numpy().tolist()
            _, loss = self.model(input_ids, segment_ids, input_mask,
                                 trigger_mask, label_ids, input_lens,
                                 one_hot_labels)

            if len(self.n_gpu.split(",")) >= 2:
                loss = loss.mean()
            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps

            loss.backward()
            if (step + 1) % self.gradient_accumulation_steps == 0:
                # clip once per optimizer update, after gradients have accumulated
                clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.global_step += 1
            tr_loss.update(loss.item(), n=1)
            self.tb_logger.scalar_summary("loss", tr_loss.avg,
                                          self.global_step)
            pbar(step=step, info={'loss': loss.item()})
            # if step%5==0:
            # 	self.logger.info("step:{},loss={:.4f}".format(self.global_step,loss.item()))

        info = {'loss': tr_loss.avg}
        if "cuda" in str(self.device):
            torch.cuda.empty_cache()
        return info
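
The loop above steps the optimizer only once every gradient_accumulation_steps batches. A minimal, self-contained sketch of that pattern, with model, optimizer, and batches as illustrative placeholders rather than objects from the source:

import torch
from torch.nn.utils import clip_grad_norm_

def train_with_accumulation(model, optimizer, batches, accum_steps=4, grad_clip=1.0):
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(batches):
        loss = torch.nn.functional.cross_entropy(model(x), y)
        # scale so the accumulated gradient matches one large-batch gradient
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            # clip once per optimizer update, after accumulation finishes
            clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()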
Example #7
        def start(self, criterion, valid_set, valid_loader, epoch=None, global_bar=None):
            confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

            val_loss = AverageMeter()

            # Initialize epoch progress bar
            val_num_batches = math.ceil(valid_set.num_images / float(self.cf.valid_batch_size))
            prev_msg = '\nValidation estimated time...\n'
            bar = ProgressBar(val_num_batches, lenBar=20)
            bar.set_prev_msg(prev_msg)
            bar.update(show=False)

            # Validate model
            for vi, data in enumerate(valid_loader):
                # Read data
                inputs, gts = data
                n_images, w, h, c = inputs.size()
                # Variable/volatile are deprecated; use plain tensors with no_grad
                inputs = inputs.cuda()
                gts = gts.cuda()

                # Predict model
                with torch.no_grad():
                    outputs = self.model.net(inputs)
                    predictions = outputs.data.max(1)[1].cpu().numpy()

                    # Compute batch stats (.data[0] is deprecated; use .item())
                    val_loss.update(criterion(outputs, gts).item() / n_images, n_images)
                    confm = compute_confusion_matrix(predictions, gts.cpu().numpy(),
                                                     self.cf.num_classes, self.cf.void_class)
                    # map() is lazy in Python 3; add into the numpy array directly
                    confm_list = confm_list + confm

                # Save epoch stats
                self.stats.val.conf_m = confm_list
                self.stats.val.loss = val_loss.avg / (w * h * c)

                # Update messages
                self.update_msg(bar, global_bar)

            # Compute stats
            self.compute_stats(np.asarray(self.stats.val.conf_m), val_loss)

            # Save stats
            self.save_stats(epoch)
Example #8
        def start(self, valid_set, valid_loader, mode='Validation', epoch=None, save_folder=None):
            confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

            self.val_loss = AverageMeter()

            # Validate model
            if self.cf.problem_type == 'detection':
                self.validation_loop(epoch, valid_loader, valid_set, save_folder)
            else:
                self.validation_loop(epoch, valid_loader, valid_set, confm_list)

            # Compute stats
            self.compute_stats(np.asarray(self.stats.val.conf_m), self.val_loss)

            # Save stats
            self.save_stats(epoch)
            if mode == 'Epoch Validation':
                self.logger_stats.write_stat(self.stats.train, epoch,
                                            os.path.join(self.cf.train_json_path,'valid_epoch_' + str(epoch) + '.json'))
            elif mode == 'Validation':
                self.logger_stats.write_stat(self.stats.val, epoch, self.cf.val_json_file)
            elif mode == 'Test':
                self.logger_stats.write_stat(self.stats.test, epoch, self.cf.test_json_file)
Example #9
    class validation(object):
        def __init__(self, logger_stats, model, cf, stats, msg):
            # Initialize validation variables
            self.logger_stats = logger_stats
            self.model = model
            self.cf = cf
            self.stats = stats
            self.msg = msg
            self.writer = SummaryWriter(
                os.path.join(cf.tensorboard_path, 'validation'))

        def start(self,
                  valid_set,
                  valid_loader,
                  mode='Validation',
                  epoch=None,
                  global_bar=None,
                  save_folder=None):
            confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

            self.val_loss = AverageMeter()

            # Initialize epoch progress bar
            val_num_batches = math.ceil(valid_set.num_images /
                                        float(self.cf.valid_batch_size))
            prev_msg = '\n' + mode + ' estimated time...\n'
            bar = ProgressBar(val_num_batches, lenBar=20)
            bar.set_prev_msg(prev_msg)
            bar.update(show=False)

            # Validate model
            if self.cf.problem_type == 'detection':
                self.validation_loop(epoch, valid_loader, valid_set, bar,
                                     global_bar, save_folder)
            else:
                self.validation_loop(epoch, valid_loader, valid_set, bar,
                                     global_bar, confm_list)

            # Compute stats
            self.compute_stats(np.asarray(self.stats.val.conf_m),
                               self.val_loss)

            # Save stats
            self.save_stats(epoch)
            if mode == 'Epoch Validation':
                self.logger_stats.write_stat(
                    self.stats.train, epoch,
                    os.path.join(self.cf.train_json_path,
                                 'valid_epoch_' + str(epoch) + '.json'))
            elif mode == 'Validation':
                self.logger_stats.write_stat(self.stats.val, epoch,
                                             self.cf.val_json_file)
            elif mode == 'Test':
                self.logger_stats.write_stat(self.stats.val, epoch,
                                             self.cf.test_json_file)

        def validation_loop(self, epoch, valid_loader, valid_set, bar,
                            global_bar, confm_list):
            for vi, data in enumerate(valid_loader):
                # Read data
                inputs, gts = data
                n_images, w, h, c = inputs.size()
                inputs = inputs.cuda()
                gts = gts.cuda()

                # Predict model
                with torch.no_grad():
                    outputs = self.model.net(inputs)
                    predictions = outputs.data.max(1)[1].cpu().numpy()

                    # Compute batch stats
                    self.val_loss.update(
                        float(
                            self.model.loss(outputs, gts).cpu().item() /
                            n_images), n_images)
                    confm = compute_confusion_matrix(predictions,
                                                     gts.cpu().data.numpy(),
                                                     self.cf.num_classes,
                                                     self.cf.void_class)
                    # map() is lazy in Python 3; add into the numpy array directly
                    confm_list = confm_list + confm

                # Save epoch stats
                self.stats.val.conf_m = confm_list
                self.stats.val.loss = self.val_loss.avg

                # Save predictions and generate overlaping
                self.update_tensorboard(
                    inputs.cpu(), gts.cpu(), predictions, epoch,
                    range(
                        vi * self.cf.valid_batch_size,
                        vi * self.cf.valid_batch_size +
                        np.shape(predictions)[0]), valid_set.num_images)

                # Update messages
                self.update_msg(bar, global_bar)

        def update_tensorboard(self, inputs, gts, predictions, epoch, indexes,
                               val_len):
            pass

        def update_msg(self, bar, global_bar):
            if global_bar is None:
                # Update progress bar
                bar.update()
            else:
                self.msg.eval_str = '\n' + bar.get_message(step=True)
                global_bar.set_msg(self.msg.accum_str + self.msg.last_str + self.msg.msg_stats_last + \
                                   self.msg.msg_stats_best + self.msg.eval_str)
                global_bar.update()

        def compute_stats(self, confm_list, val_loss):
            TP_list, TN_list, FP_list, FN_list = extract_stats_from_confm(
                confm_list)
            mean_accuracy = compute_accuracy(TP_list, TN_list, FP_list,
                                             FN_list)
            self.stats.val.acc = np.nanmean(mean_accuracy)
            self.stats.val.loss = val_loss.avg

        def save_stats(self, epoch):
            # Save logger
            if epoch is not None:
                self.logger_stats.write(
                    '----------------- Epoch scores summary ------------------------- \n'
                )
                self.logger_stats.write(
                    '[epoch %d], [val loss %.5f], [acc %.2f] \n' %
                    (epoch, self.stats.val.loss, 100 * self.stats.val.acc))
                self.logger_stats.write(
                    '---------------------------------------------------------------- \n'
                )
            else:
                self.logger_stats.write(
                    '----------------- Scores summary -------------------- \n')
                self.logger_stats.write(
                    '[val loss %.5f], [acc %.2f] \n' %
                    (self.stats.val.loss, 100 * self.stats.val.acc))
                self.logger_stats.write(
                    '---------------------------------------------------------------- \n'
                )
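
The compute_stats methods above derive accuracy from per-class TP/TN/FP/FN counts. The extract_stats_from_confm helper is not shown on this page; here is a minimal sketch, assuming rows of the confusion matrix are ground truth and columns are predictions:

import numpy as np

def stats_from_confm(confm):
    # per-class counts from a square confusion matrix (rows: truth, cols: prediction)
    confm = np.asarray(confm, dtype=float)
    TP = np.diag(confm)
    FP = confm.sum(axis=0) - TP
    FN = confm.sum(axis=1) - TP
    TN = confm.sum() - TP - FP - FN
    return TP, TN, FP, FN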
Example #10
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir,
            args.data_list,
            joint_transform=joint_transform,
            transform=input_transform,
            target_transform=target_transform,
        )
    loader = data.DataLoader(dataset,
                             batch_size=args.batch_size,
                             shuffle=True,
                             num_workers=args.num_workers,
                             pin_memory=True)
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target,
        args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform,
        target_transform=target_transform)
    val_loader = data.DataLoader(val_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=True,
                                 num_workers=args.num_workers,
                                 pin_memory=True)

    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear',
                           align_corners=True)

    net = resnet101_ibn_a_deeplab(args.model_path_prefix,
                                  n_classes=args.n_classes)
    # optimizer = get_seg_optimizer(net, args)
    optimizer = torch.optim.SGD(net.parameters(), lr=args.learning_rate,
                                momentum=args.momentum)
    net = torch.nn.DataParallel(net).cuda()
    # size_average was removed from torch.nn losses; reduction='sum' is the equivalent
    criterion = torch.nn.CrossEntropyLoss(reduction='sum',
                                          ignore_index=args.ignore_index)

    num_batches = len(loader)
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            show_fig = (batch_index + 1) % args.show_img_freq == 0
            iteration = batch_index + 1 + epoch * num_batches

            # poly_lr_scheduler(
            #     optimizer=optimizer,
            #     init_lr=args.learning_rate,
            #     iter=iteration - 1,
            #     lr_decay_iter=args.lr_decay,
            #     max_iter=args.num_epoch*num_batches,
            #     power=args.poly_power,
            # )

            net.train()
            # net.module.freeze_bn()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time() - tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time() - tem_time)
            tem_time = time.time()

            if (batch_index + 1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}')
            if show_fig:
                base_lr = optimizer.param_groups[0]["lr"]
                output = torch.argmax(output, dim=1).detach()[0, ...].cpu()
                fig, axes = plt.subplots(2, 1, figsize=(12, 14))
                axes = axes.flat
                axes[0].imshow(colorize_mask(output.numpy()))
                axes[0].set_title(name[0])
                axes[1].imshow(colorize_mask(label[0, ...].numpy()))
                axes[1].set_title(f'seg_true_{base_lr:.6f}')
                writer.add_figure('A_seg', fig, iteration)

        mean_iu = test_miou(net, val_loader, upsample,
                            './ae_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix,
                         f'{epoch:d}_{mean_iu*100:.0f}.pth'))

    writer.close()
Example #11
    def train(self, src_loader, tgt_loader, val_loader, writer):
        tgt_domain_label = 1
        src_domain_label = 0
        num_batches = min(len(src_loader), len(tgt_loader))

        self.G_optimizer.param_groups[0]['lr'] *= 100
        self.D_optimizer.param_groups[0]['lr'] *= 100
        for epoch in range(self.opt.warm_up_epoch):
            for batch_index, batch_data in enumerate(zip(src_loader, tgt_loader)):
                self.G.train()
                src_batch, tgt_batch = batch_data
                src_img, _ = src_batch
                tgt_img, _ = tgt_batch
                src_img_cuda = src_img.cuda()
                tgt_img_cuda = tgt_img.cuda()

                rec_tgt = self.G(tgt_img_cuda)  # output [-1,1]
                rec_loss = self.mse_criterion(rec_tgt, tgt_img_cuda)
                self.G_optimizer.zero_grad()
                rec_loss.backward()
                self.G_optimizer.step()

                tgt_img_cuda = self.de_normalize(tgt_img_cuda).detach()
                tgt_D_loss = self.compute_discrim_loss(
                    tgt_img_cuda, tgt_domain_label
                )
                src_D_loss = self.compute_discrim_loss(
                    src_img_cuda, src_domain_label
                )
                D_loss = tgt_D_loss + src_D_loss
                self.D_optimizer.zero_grad()
                D_loss.backward()
                self.D_optimizer.step()

                if (batch_index+1) % self.opt.print_freq == 0:
                    print(
                        f'Warm Up Epoch [{epoch+1:d}/{self.opt.warm_up_epoch:d}]'
                        f'[{batch_index+1:d}/{num_batches:d}]\t'
                        f'G Loss: {rec_loss.item():.2f}   '
                        f'D Loss: {D_loss.item():.2f}'
                    )

        self.G_optimizer.param_groups[0]['lr'] /= 100
        self.D_optimizer.param_groups[0]['lr'] /= 100
        for epoch in range(self.opt.gen_epochs):

            content_loss_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, batch_data in enumerate(zip(src_loader, tgt_loader)):
                iteration = batch_index+1+epoch*num_batches

                self.G.train()
                src_batch, tgt_batch = batch_data
                src_img, _ = src_batch
                tgt_img, _ = tgt_batch
                src_img_cuda = src_img.cuda()
                tgt_img_cuda = tgt_img.cuda()
                data_time_rec.update(time.time()-tem_time)

                rec_tgt = self.G(tgt_img_cuda) # output [-1,1]
                if (batch_index+1) % self.opt.show_img_freq == 0:
                    rec_results = rec_tgt.detach().clone().cpu()
                # return to [0,1], for VGG takes input [0,1]
                rec_tgt = self.de_normalize(rec_tgt) 
                tgt_img_cuda = self.de_normalize(tgt_img_cuda)

                content_loss = self.compute_content_loss(rec_tgt, tgt_img_cuda)
                loss_style = content_loss * self.lambda_values[0]

                # adv train G
                for param in self.D.parameters():
                    param.requires_grad = False

                adv_tgt_rec_discrim_loss = self.compute_discrim_loss(
                    rec_tgt, src_domain_label
                )
                G_loss = loss_style +\
                         adv_tgt_rec_discrim_loss * self.lambda_values[1]

                self.G_optimizer.zero_grad()
                G_loss.backward()
                self.G_optimizer.step()

                # train D
                for param in self.D.parameters():
                    param.requires_grad = True
                rec_tgt = rec_tgt.detach()

                tgt_rec_discrim_loss = self.compute_discrim_loss(
                    rec_tgt, tgt_domain_label
                )
                tgt_discrim_loss = self.compute_discrim_loss(
                    tgt_img_cuda, tgt_domain_label
                )
                src_discrim_loss = self.compute_discrim_loss(
                    src_img_cuda, src_domain_label
                )
                D_loss = 0.5 * (tgt_rec_discrim_loss + tgt_discrim_loss) +\
                         src_discrim_loss

                self.D_optimizer.zero_grad()
                D_loss.backward()
                self.D_optimizer.step()

                content_loss_rec.update(content_loss.item())
                writer.add_scalar(
                    'content_loss', content_loss.item(), iteration
                )
                writer.add_scalar(
                    'G_loss', G_loss.item(), iteration
                )
                writer.add_scalar(
                    'D_loss', D_loss.item(), iteration
                )
                batch_time_rec.update(time.time()-tem_time)
                tem_time = time.time()

                if (batch_index+1) % self.opt.print_freq == 0:
                    print(
                        f'Epoch [{epoch+1:d}/{self.opt.gen_epochs:d}]'
                        f'[{batch_index+1:d}/{num_batches:d}]\t'
                        f'Time: {batch_time_rec.avg:.2f}   '
                        f'Data: {data_time_rec.avg:.2f}   '
                        f'Loss: {content_loss_rec.avg:.2f}'
                    )
                if (batch_index+1) % self.opt.show_img_freq == 0:
                    fig, axes = plt.subplots(2, 1, figsize=(12, 6))
                    axes = axes.flat
                    axes[0].imshow(self.to_image(rec_results[0, ...]))
                    label_new = self.compute_cls_label(rec_results)[0]
                    axes[0].set_title(f'rec_label_{label_new}')
                    axes[1].imshow(self.to_image(tgt_img[0, ...]))
                    label_ori = self.compute_cls_label(tgt_img)[0]
                    axes[1].set_title(f'ori_label_{label_ori}')
                    writer.add_figure('Gen', fig, iteration)

                if iteration % self.opt.checkpoint_freq == 0:

                    acc = cls_evaluate(self.combine_model, val_loader)

                    model_name = time.strftime(
                        '%m%d_%H%M_', time.localtime(time.time())
                    ) + str(iteration) + f'_{acc*1000:.0f}.pth'
                    torch.save(
                        self.G.module.state_dict(), 
                        os.path.join(self.opt.save_path_prefix, model_name)
                    )
                    print(f'Model saved as {model_name}')
Example #12
def train_epoch(model, train_dataloader, optimizer, epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    cls_losses = AverageMeter()
    reg_losses = AverageMeter()
    iou_losses = AverageMeter()
    center_losses = AverageMeter()
    inner_losses = AverageMeter()
    end = time.time()

    model.train()
    optimizer.zero_grad()

    for iter, (vid_names, props_start_end, props_features,
               gt_start_end, query_tokens, query_len, props_num, num_frames) in enumerate(train_dataloader):

        data_time.update(time.time() - end)
        bs = props_features.size(0)

        box_lists, loss_dict = model(
            query_tokens, query_len, props_features, props_start_end, gt_start_end, props_num, num_frames
        )

        if args.is_second_stage:
            loss = loss_dict['loss_iou']
        else:
            loss = sum(loss for loss in loss_dict.values())

        losses.update(loss.item(), bs)
        cls_losses.update(loss_dict["loss_cls"].item(), bs)
        reg_losses.update(loss_dict["loss_reg"].item(), bs)
        iou_losses.update(loss_dict['loss_iou'].item(), bs)
        # center_losses.update(loss_dict["loss_centerness"].item(), bs)
        # inner_losses.update(loss_dict['loss_innerness'].item(), bs)

        # print(losses.avg)
        if loss != 0:
            loss.backward()

        if args.clip_gradient is not None:
            total_norm = clip_grad_norm_(model.parameters(), args.clip_gradient)
            # if total_norm > args.clip_gradient:
            #     logger.info("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm))

        optimizer.step()
        optimizer.zero_grad()

        batch_time.update(time.time() - end)
        end = time.time()

        writer.add_scalar('train_data/loss', losses.val, epoch * len(train_dataloader) + iter + 1)
        writer.add_scalar('train_data/cls_loss', cls_losses.val, epoch * len(train_dataloader) + iter + 1)
        writer.add_scalar('train_data/reg_loss', reg_losses.val, epoch * len(train_dataloader) + iter + 1)
        writer.add_scalar('train_data/iou_loss', iou_losses.val, epoch * len(train_dataloader) + iter + 1)
        # writer.add_scalar('train_data/center_loss', center_losses.val, epoch * len(train_dataloader) + iter + 1)
        # writer.add_scalar('train_data/inner_loss', inner_losses.val, epoch * len(train_dataloader) + iter + 1)

        if iter % args.print_freq == 0:
            logger.info(
                'Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                    epoch, iter, len(train_dataloader), batch_time=batch_time,data_time=data_time, loss=losses)
            )

    writer.add_scalar('train_epoch_data/epoch_loss', losses.avg, epoch)
    writer.add_scalar('train_epoch_data/epoch_cls_loss', cls_losses.avg, epoch)
    writer.add_scalar('train_epoch_data/epoch_reg_loss', reg_losses.avg, epoch)
    writer.add_scalar('train_epoch_data/epoch_iou_loss', iou_losses.avg, epoch)
    # writer.add_scalar('train_epoch_data/epoch_center_loss', center_losses.avg, epoch )
    # writer.add_scalar('train_epoch_data/epoch_inner_loss', inner_losses.avg, epoch)

    return losses.avg
Example #13
def validate_epoch(trained_model, test_dataloader, epoch, word2idx, save_results=False):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    cls_losses = AverageMeter()
    reg_losses = AverageMeter()
    center_losses = AverageMeter()
    iou_losses = AverageMeter()
    end = time.time()

    trained_model.eval()
    results_dict = {}
    id2word = {idx: word for word, idx in word2idx.items()}

    with torch.no_grad():
        for iter, (vid_names, props_start_end, props_features,
                   gt_start_end, query_tokens, query_len, props_num, num_frames) in enumerate(test_dataloader):

            data_time.update(time.time() - end)
            bs = props_features.size(0)

            box_lists, loss_dict = trained_model(
                query_tokens, query_len, props_features, props_start_end, gt_start_end, props_num, num_frames
            )

            if args.is_second_stage:
                loss = loss_dict['loss_iou']
            else:
                loss = sum(loss for loss in loss_dict.values())

            losses.update(loss.item(), bs)
            cls_losses.update(loss_dict["loss_cls"].item(), bs)
            reg_losses.update(loss_dict["loss_reg"].item(), bs)
            iou_losses.update(loss_dict['loss_iou'].item(), bs)
            # center_losses.update(loss_dict["loss_centerness"].item(), bs)
            # inner_losses.update(loss_dict['loss_innerness'].item(), bs)

            batch_time.update(time.time() - end)
            end = time.time()

            if iter % args.print_freq == 0:
                logger.info(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})'.format(
                        epoch, iter, len(test_dataloader), batch_time=batch_time, data_time=data_time, loss=losses)
                )

            for i in range(bs):
                vid_name = vid_names[i]
                query_length = query_len[i]
                query = ' '.join(id2word[x.item()] for x in query_tokens[i, :query_length])
                gt = gt_start_end[i].numpy().tolist()
                valid_props_num = props_num[i]

                per_vid_detections = box_lists[i]["detections"]
                per_vid_scores = box_lists[i]["scores"]
                per_vid_level = box_lists[i]['level']

                props_pred = torch.cat((per_vid_detections, per_vid_scores.unsqueeze(-1)), dim=-1)
                # edge_pred_info = edge_pred[i, :valid_props_num, :valid_props_num, :].permute(1, 2, 0).contiguous()
                temp_dict = {
                    'query': query,
                    'gt': gt,
                    'node_predictions': props_pred.cpu().numpy().tolist(),
                    'edge_predictions': props_pred.cpu().numpy().tolist(),
                    'level': per_vid_level
                }
                results_dict.setdefault(vid_name, []).append(temp_dict)

        writer.add_scalar('val_epoch_data/epoch_loss', losses.avg, epoch)
        writer.add_scalar('val_epoch_data/epoch_cls_loss', cls_losses.avg, epoch)
        writer.add_scalar('val_epoch_data/epoch_reg_loss', reg_losses.avg, epoch)
        writer.add_scalar('val_epoch_data/epoch_iou_loss', iou_losses.avg, epoch)
        # writer.add_scalar('val_epoch_data/epoch_center_loss', center_losses.avg, epoch)
        # writer.add_scalar('val_epoch_data/epoch_inner_loss', inner_losses.avg, epoch)

        if save_results:
            results_folder = './results/Evaluate/Raw_results'
            os.makedirs(results_folder, exist_ok=True)
            date = time.strftime('%Y-%m-%d-%H-%M', time.localtime(time.time()))
            json.dump(results_dict, open(os.path.join(results_folder, f'raw_results_{date}.json'), 'w'), indent=4)
        iou_topk_dict = {"iou": [0.5], 'topk': [1, 5]}
        postprocess_runner = PostProcessRunner(results_dict)
        topks, accuracy_topks = postprocess_runner.run_evaluate(iou_topk_dict=iou_topk_dict, temporal_nms=True)

    return losses.avg, topks, accuracy_topks
Example #14
def main(args):
    writer = SummaryWriter(log_dir=args.tensorboard_log_dir)
    w, h = map(int, args.input_size.split(','))
    w_target, h_target = map(int, args.input_size_target.split(','))

    joint_transform = joint_transforms.Compose([
        joint_transforms.FreeScale((h, w)),
        joint_transforms.RandomHorizontallyFlip(),
    ])
    normalize = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*normalize),
    ])
    target_transform = extended_transforms.MaskToTensor()
    restore_transform = standard_transforms.ToPILImage()

    if '5' in args.data_dir:
        dataset = GTA5DataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    else:
        dataset = CityscapesDataSetLMDB(
            args.data_dir, args.data_list,
            joint_transform=joint_transform,
            transform=input_transform, target_transform=target_transform,
        )
    loader = data.DataLoader(
        dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers, pin_memory=True
    )
    val_dataset = CityscapesDataSetLMDB(
        args.data_dir_target, args.data_list_target,
        # joint_transform=joint_transform,
        transform=input_transform, target_transform=target_transform
    )
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size,
        shuffle=False, num_workers=args.num_workers, pin_memory=True
    )


    upsample = nn.Upsample(size=(h_target, w_target),
                           mode='bilinear', align_corners=True)

    net = PSP(
        nclass=args.n_classes, backbone='resnet101',
        root=args.model_path_prefix, norm_layer=BatchNorm2d,
    )

    params_list = [
        {'params': net.pretrained.parameters(), 'lr': args.learning_rate},
        {'params': net.head.parameters(), 'lr': args.learning_rate*10},
        {'params': net.auxlayer.parameters(), 'lr': args.learning_rate*10},
    ]
    optimizer = torch.optim.SGD(params_list,
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    criterion = SegmentationLosses(nclass=args.n_classes, aux=True, ignore_index=255)
    # criterion = SegmentationMultiLosses(nclass=args.n_classes, ignore_index=255)

    net = DataParallelModel(net).cuda()
    criterion = DataParallelCriterion(criterion).cuda()

    logger = utils.create_logger(args.tensorboard_log_dir, 'PSP_train')
    scheduler = utils.LR_Scheduler(args.lr_scheduler, args.learning_rate,
                                   args.num_epoch, len(loader), logger=logger,
                                   lr_step=args.lr_step)

    net_eval = Eval(net)

    num_batches = len(loader)
    best_pred = 0.0
    for epoch in range(args.num_epoch):

        loss_rec = AverageMeter()
        data_time_rec = AverageMeter()
        batch_time_rec = AverageMeter()

        tem_time = time.time()
        for batch_index, batch_data in enumerate(loader):
            scheduler(optimizer, batch_index, epoch, best_pred)
            show_fig = (batch_index+1) % args.show_img_freq == 0
            iteration = batch_index+1+epoch*num_batches

            net.train()
            img, label, name = batch_data
            img = img.cuda()
            label_cuda = label.cuda()
            data_time_rec.update(time.time()-tem_time)

            output = net(img)
            loss = criterion(output, label_cuda)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_rec.update(loss.item())
            writer.add_scalar('A_seg_loss', loss.item(), iteration)
            batch_time_rec.update(time.time()-tem_time)
            tem_time = time.time()

            if (batch_index+1) % args.print_freq == 0:
                print(
                    f'Epoch [{epoch+1:d}/{args.num_epoch:d}][{batch_index+1:d}/{num_batches:d}]\t'
                    f'Time: {batch_time_rec.avg:.2f}   '
                    f'Data: {data_time_rec.avg:.2f}   '
                    f'Loss: {loss_rec.avg:.2f}'
                )
            # if show_fig:
            #     # base_lr = optimizer.param_groups[0]["lr"]
            #     output = torch.argmax(output[0][0], dim=1).detach()[0, ...].cpu()
            #     # fig, axes = plt.subplots(2, 1, figsize=(12, 14))
            #     # axes = axes.flat
            #     # axes[0].imshow(colorize_mask(output.numpy()))
            #     # axes[0].set_title(name[0])
            #     # axes[1].imshow(colorize_mask(label[0, ...].numpy()))
            #     # axes[1].set_title(f'seg_true_{base_lr:.6f}')
            #     # writer.add_figure('A_seg', fig, iteration)
            #     output_mask = np.asarray(colorize_mask(output.numpy()))
            #     label = np.asarray(colorize_mask(label[0,...].numpy()))
            #     image_out = np.concatenate([output_mask, label])
            #     writer.add_image('A_seg', image_out, iteration)

        mean_iu = test_miou(net_eval, val_loader, upsample,
                            './style_seg/dataset/info.json')
        torch.save(
            net.module.state_dict(),
            os.path.join(args.save_path_prefix, f'{epoch:d}_{mean_iu*100:.0f}.pth')
        )

    writer.close()
Example #15
def train(model, dataloader, scaler, optimizer, scheduler, device, args):
    pbar = ProgressBar(n_total=len(dataloader), desc='Training')
    train_loss = AverageMeter()
    train_acc = AverageMeter()
    train_f1 = AverageMeter()
    count = 0
    model.train()
    for batch_idx, batch in enumerate(dataloader):
        input_ids, attn_mask, token_type_ids, label, idx = (
            batch['input_ids'].to(device), batch['attention_mask'].to(device),
            batch['token_type_ids'].to(device), batch['labels'].to(device),
            batch['idx'].to(device))
        optimizer.zero_grad()
        with autocast():
            out = model(input_ids=input_ids.squeeze(1),
                        attention_mask=attn_mask.squeeze(1),
                        token_type_ids=token_type_ids.squeeze(1),
                        labels=label)

        if args.num_labels > 1:
            pred = out['logits'].argmax(dim=1, keepdim=True)
            correct = pred.eq(label.view_as(pred)).sum().item()
            # sklearn's f1_score expects ground truth first (y_true, y_pred)
            f1 = f1_score(label.cpu().numpy(),
                          pred.view(-1).cpu().numpy(),
                          average='weighted')
            train_f1.update(f1, n=input_ids.size(0))
            train_acc.update(correct, n=1)
        else:
            pred = out['logits']

        scaler.scale(out['loss']).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()
        count += input_ids.size(0)
        train_loss.update(out['loss'].item(), n=1)
        pbar(step=batch_idx, info={'loss': train_loss.avg})
    return {
        'loss': train_loss.avg,
        'acc': train_acc.sum / count,
        'f1': train_f1.avg
    }
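
As a quick reference for the f1_score call above, scikit-learn takes ground truth first:

from sklearn.metrics import f1_score

# y_true first, y_pred second; 'weighted' averages per-class F1 by support
f1 = f1_score([0, 1, 1, 0], [0, 1, 0, 0], average='weighted')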
Example #16
        def start(self,
                  train_loader,
                  train_set,
                  valid_set=None,
                  valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images /
                                               float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(valid_set.num_images / \
                                                                    float(self.cf.valid_batch_size))
            # Define early stopping control
            if self.cf.early_stopping:
                early_Stopping = Early_Stopping(self.cf)
            else:
                early_Stopping = None

            prev_msg = '\nTotal estimated training time...\n'
            self.global_bar = ProgressBar(
                (self.cf.epochs + 1 - self.curr_epoch) *
                (self.train_num_batches + self.val_num_batches),
                lenBar=20)
            self.global_bar.set_prev_msg(prev_msg)

            # Train process
            for epoch in range(self.curr_epoch, self.cf.epochs + 1):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                epoch_time = time.time()
                self.logger_stats.write('\t ------ Epoch: ' + str(epoch) +
                                        ' ------ \n')

                # Initialize epoch progress bar
                self.msg.accum_str = '\n\nEpoch %d/%d estimated time...\n' % \
                                     (epoch, self.cf.epochs)
                epoch_bar = ProgressBar(self.train_num_batches, lenBar=20)
                epoch_bar.update(show=False)

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros(
                    (self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader, epoch_bar)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(np.asarray(self.confm_list),
                                   self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(
                    self.stats.train, epoch,
                    os.path.join(self.cf.train_json_path,
                                 'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_Stopping,
                                    epoch, self.global_bar)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch,
                                                       self.cf.best_json_file)

                # Update display values
                self.update_messages(epoch, epoch_time, new_best)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()
Example #17
    class train:
        def __init__(self, logger_stats, model, cf, validator, stats, msg):
            # Initialize training variables
            self.logger_stats = logger_stats
            self.model = model
            self.cf = cf
            self.validator = validator
            self.logger_stats.write('\n- Starting train <--- \n')
            self.curr_epoch = 1 if self.model.best_stats.epoch == 0 else self.model.best_stats.epoch
            self.stop = False
            self.stats = stats
            self.best_acc = 0
            self.msg = msg
            self.loss = None
            self.outputs = None
            self.labels = None
            self.writer = SummaryWriter(os.path.join(cf.tensorboard_path, 'train'))

        def start(self, train_loader, train_set, valid_set=None, valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images / float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(
                valid_set.num_images / float(self.cf.valid_batch_size))

            # Define early stopping control
            if self.cf.early_stopping:
                early_stopping = EarlyStopping(self.cf)
            else:
                early_stopping = None

            # Train process
            for epoch in tqdm(range(self.curr_epoch, self.cf.epochs + 1), desc='Training', file=sys.stdout):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                self.logger_stats.write('\n\t ------ Epoch: ' + str(epoch) + ' ------ \n')

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(self.confm_list, self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(self.stats.train, epoch,
                                             os.path.join(self.cf.train_json_path,
                                                          'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_stopping, epoch)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch, self.cf.best_json_file)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()

        def training_loop(self, epoch, train_loader):
            # Train epoch
            for i, data in tqdm(enumerate(train_loader), desc="Epoch {}/{}".format(epoch, self.cf.epochs),
                                total=len(train_loader), file=sys.stdout):
                # Read Data
                inputs, labels = data

                n, c, w, h = inputs.size()
                inputs = inputs.cuda()
                self.inputs = inputs
                self.labels = labels.cuda()

                # Predict model
                self.model.optimizer.zero_grad()
                self.outputs = self.model.net(inputs)
                predictions = self.outputs.data.max(1)[1].cpu().numpy()

                # Compute gradients
                self.compute_gradients()

                # Compute batch stats
                self.train_loss.update(float(self.loss.cpu().item()), n)
                confm = compute_confusion_matrix(predictions, self.labels.cpu().data.numpy(), self.cf.num_classes,
                                                 self.cf.void_class)
                self.confm_list = self.confm_list + confm

                self.stats.train.loss = self.train_loss.avg

                if not self.cf.debug:
                    # Save stats
                    self.save_stats_batch((epoch - 1) * self.train_num_batches + i)

        def save_stats_epoch(self, epoch):
            # Save logger
            if epoch is not None:
                # Epoch loss tensorboard
                self.writer.add_scalar('losses/epoch', self.stats.train.loss, epoch)
                self.writer.add_scalar('metrics/accuracy', 100. * self.stats.train.acc, epoch)

        def save_stats_batch(self, batch):
            # Save logger
            if batch is not None:
                self.writer.add_scalar('losses/batch', self.stats.train.loss, batch)

        def compute_gradients(self):
            self.loss = self.model.loss(self.outputs, self.labels)
            self.loss.backward()
            self.model.optimizer.step()

        def compute_stats(self, confm_list, train_loss):
            TP_list, TN_list, FP_list, FN_list = extract_stats_from_confm(confm_list)
            mean_accuracy = compute_accuracy(TP_list, TN_list, FP_list, FN_list)
            self.stats.train.acc = np.nanmean(mean_accuracy)
            self.stats.train.loss = float(train_loss.avg)

        def validate_epoch(self, valid_set, valid_loader, early_stopping, epoch):
            if valid_set is not None and valid_loader is not None:
                # Set model in validation mode
                self.model.net.eval()

                self.validator.start(valid_set, valid_loader, 'Epoch Validation', epoch)
                print(self.stats)

                # Early stopping checking
                if self.cf.early_stopping:
                    if early_stopping.check(self.stats.train.loss, self.stats.val.loss, self.stats.val.mIoU,
                                            self.stats.val.acc, self.stats.val.f1score):
                        self.stop = True
                # Set model in training mode
                self.model.net.train()
Example #18
0
def evaluate_single_epoch(config, model, dataloader, criterion, log_val, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    score_1 = AverageMeter()
    score_2 = AverageMeter()
    score_3 = AverageMeter()
    score_4 = AverageMeter()

    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            # logits = model(images)
            logits = model(images)['out']

            loss = criterion(logits, labels)
            losses.update(loss.item(), images.shape[0])

            preds = torch.sigmoid(logits)  # F.sigmoid is deprecated in favor of torch.sigmoid

            score = dice_coef(preds, labels)
            score_1.update(score[0].item(), images.shape[0])
            score_2.update(score[1].item(), images.shape[0])
            score_3.update(score[2].item(), images.shape[0])
            score_4.update(score[3].item(), images.shape[0])
            scores.update(score.mean().item(), images.shape[0])

            batch_time.update(time.time() - end)
            end = time.time()

            if i % config.PRINT_EVERY == 0:
                print('[%2d/%2d] time: %.2f, loss: %.6f, score: %.4f [%.4f, %.4f, %.4f, %.4f]'
                      % (i, len(dataloader), batch_time.sum, loss.item(), score.mean().item(), score[0].item(), score[1].item(), score[2].item(), score[3].item()))

            del images, labels, logits, preds
            torch.cuda.empty_cache()

        log_val.write('[%d/%d] loss: %.6f, score: %.4f [%.4f, %.4f, %.4f, %.4f]\n'
                      % (epoch, config.TRAIN.NUM_EPOCHS, losses.avg, scores.avg, score_1.avg, score_2.avg, score_3.avg, score_4.avg))
        print('average loss over VAL epoch: %f' % losses.avg)

    return scores.avg, losses.avg
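
Example #18 indexes four per-class scores out of dice_coef(preds, labels), but the helper itself is not shown. A plausible soft-Dice sketch, assuming preds are per-class sigmoid probabilities of shape (N, C, H, W) and labels are binary masks of the same shape (the original may threshold first):

import torch

def dice_coef(preds, labels, eps=1e-7):
    # per-class soft Dice over the whole batch; returns a tensor of shape (C,)
    dims = (0, 2, 3)
    intersection = (preds * labels).sum(dims)
    cardinality = preds.sum(dims) + labels.sum(dims)
    return (2.0 * intersection + eps) / (cardinality + eps)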
Example #19
0
def train_single_epoch(config, model, dataloader, criterion, optimizer, log_train, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    score_1 = AverageMeter()
    score_2 = AverageMeter()
    score_3 = AverageMeter()
    score_4 = AverageMeter()

    model.train()

    end = time.time()
    for i, (images, labels) in enumerate(dataloader):
        optimizer.zero_grad()

        images = images.to(device)
        labels = labels.to(device)

        # logits = model(images)
        logits = model(images)['out']

        if config.LABEL_SMOOTHING:
            smoother = LabelSmoother()
            loss = criterion(logits, smoother(labels))
        else:
            loss = criterion(logits, labels)

        losses.update(loss.item(), images.shape[0])

        loss.backward()
        optimizer.step()

        preds = torch.sigmoid(logits)  # F.sigmoid is deprecated in favor of torch.sigmoid

        score = dice_coef(preds, labels)
        score_1.update(score[0].item(), images.shape[0])
        score_2.update(score[1].item(), images.shape[0])
        score_3.update(score[2].item(), images.shape[0])
        score_4.update(score[3].item(), images.shape[0])
        scores.update(score.mean().item(), images.shape[0])

        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_EVERY == 0:
            print("[%d/%d][%d/%d] time: %.2f, loss: %.6f, score: %.4f [%.4f, %.4f, %.4f, %.4f], lr: %.6f"
                  % (epoch, config.TRAIN.NUM_EPOCHS, i, len(dataloader), batch_time.sum, loss.item(), score.mean().item(),
                     score[0].item(), score[1].item(), score[2].item(), score[3].item(),
                     optimizer.param_groups[0]['lr']))
                     # optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr']))

        del images, labels, logits, preds
        torch.cuda.empty_cache()

    log_train.write('[%d/%d] loss: %.6f, score: %.4f, dice: [%.4f, %.4f, %.4f, %.4f], lr: %.6f\n'
                    % (epoch, config.TRAIN.NUM_EPOCHS, losses.avg, scores.avg, score_1.avg, score_2.avg, score_3.avg, score_4.avg,
                       optimizer.param_groups[0]['lr']))
                       # optimizer.param_groups[0]['lr'], optimizer.param_groups[1]['lr']))
    print('average loss over TRAIN epoch: %f' % losses.avg)
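
LabelSmoother is used here (and again in Example #29) as a callable over binary masks but is never defined. A minimal sketch under the usual binary-smoothing assumption, with eps as an assumed default:

class LabelSmoother:
    def __init__(self, eps=0.1):
        self.eps = eps

    def __call__(self, labels):
        # maps 0 -> eps/2 and 1 -> 1 - eps/2; plain arithmetic, so it works
        # for torch and tf tensors alike
        return labels * (1.0 - self.eps) + 0.5 * self.eps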
Example #20
0
    def train(self, ii, logger):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        next_data, next_loader = self._get_data(flag='train')
        test_data, test_loader = self._get_data(flag='test')
        if self.args.rank == 1:
            train_data, train_loader = self._get_data(flag='train')

        path = os.path.join(self.args.path, str(ii))
        try:
            os.mkdir(path)
        except FileExistsError:
            pass
        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience,
                                       verbose=True,
                                       rank=self.args.rank)

        W_optim, A_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []
            rate_counter = AverageMeter()
            Ag_counter, A_counter, Wg_counter, W_counter = AverageMeter(
            ), AverageMeter(), AverageMeter(), AverageMeter()

            self.model.train()
            epoch_time = time.time()
            for i, (trn_data, val_data, next_data) in enumerate(
                    zip(train_loader, vali_loader, next_loader)):
                # use a separate index so the outer batch counter `i` is not clobbered
                for j in range(len(trn_data)):
                    trn_data[j], val_data[j], next_data[j] = (
                        trn_data[j].float().to(self.device),
                        val_data[j].float().to(self.device),
                        next_data[j].float().to(self.device))
                iter_count += 1
                A_optim.zero_grad()
                rate = self.arch.unrolled_backward(
                    self.args, trn_data, val_data, next_data,
                    W_optim.param_groups[0]['lr'], W_optim)
                rate_counter.update(rate)
                # for r in range(1, self.args.world_size):
                #     for n, h in self.model.named_H():
                #         if "proj.{}".format(r) in n:
                #             if self.args.rank <= r:
                #                 with torch.no_grad():
                #                     dist.all_reduce(h.grad)
                #                     h.grad *= self.args.world_size/r+1
                #             else:
                #                 z = torch.zeros(h.shape).to(self.device)
                #                 dist.all_reduce(z)
                for a in self.model.A():
                    with torch.no_grad():
                        dist.all_reduce(a.grad)
                a_g_norm = 0
                a_norm = 0
                n = 0
                for a in self.model.A():
                    a_g_norm += a.grad.mean()
                    a_norm += a.mean()
                    n += 1
                Ag_counter.update(a_g_norm / n)
                A_counter.update(a_norm / n)

                A_optim.step()

                W_optim.zero_grad()
                pred, true = self._process_one_batch(train_data, trn_data)
                loss = criterion(pred, true)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    logger.info(
                        "\tR{0} iters: {1}, epoch: {2} | loss: {3:.7f}".format(
                            self.args.rank, i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * (
                        (self.args.train_epochs - epoch) * train_steps - i)
                    logger.info(
                        '\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(
                            speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(W_optim)
                    scaler.update()
                else:
                    loss.backward()

                    w_g_norm = 0
                    w_norm = 0
                    n = 0
                    for w in self.model.W():
                        w_g_norm += w.grad.mean()
                        w_norm += w.mean()
                        n += 1
                    Wg_counter.update(w_g_norm / n)
                    W_counter.update(w_norm / n)

                    W_optim.step()

            logger.info("R{} Epoch: {} W:{} Wg:{} A:{} Ag:{} rate{}".format(
                self.args.rank, epoch + 1, W_counter.avg, Wg_counter.avg,
                A_counter.avg, Ag_counter.avg, rate_counter.avg))

            logger.info("R{} Epoch: {} cost time: {}".format(
                self.args.rank, epoch + 1,
                time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            logger.info(
                "R{0} Epoch: {1}, Steps: {2} | Train Loss: {3:.7f} Vali Loss: {4:.7f} Test Loss: {5:.7f}"
                .format(self.args.rank, epoch + 1, train_steps, train_loss,
                        vali_loss, test_loss))
            early_stopping(vali_loss, self.model, path)

            flag = torch.tensor(
                [1]) if early_stopping.early_stop else torch.tensor([0])
            flag = flag.to(self.device)
            # NOTE: hard-coded for world_size == 2 (one flag tensor per rank)
            flags = [
                torch.tensor([1]).to(self.device),
                torch.tensor([1]).to(self.device)
            ]
            dist.all_gather(flags, flag)
            if flags[0].item() == 1 and flags[1].item() == 1:
                logger.info("Early stopping")
                break

            adjust_learning_rate(W_optim, epoch + 1, self.args)

        best_model_path = path + '/' + '{}_checkpoint.pth'.format(
            self.args.rank)
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model
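
Example #20 interleaves an architecture step (A_optim, driven by arch.unrolled_backward on validation data) with a weight step (W_optim on training data), DARTS-style. The second-order unrolled gradient and the distributed all_reduce are specific to that project; a self-contained first-order sketch of the same alternating update, using stand-in tensors and modules, looks like this:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)                  # stand-in for the searched network weights W
alpha = nn.Parameter(torch.zeros(2))     # stand-in architecture parameters A
W_optim = torch.optim.SGD(model.parameters(), lr=0.01)
A_optim = torch.optim.Adam([alpha], lr=3e-4)
criterion = nn.CrossEntropyLoss()

def forward(x):
    return model(x) + alpha              # architecture parameters shape the output

for step in range(10):
    x_trn, y_trn = torch.randn(4, 8), torch.randint(0, 2, (4,))
    x_val, y_val = torch.randn(4, 8), torch.randint(0, 2, (4,))

    A_optim.zero_grad()
    criterion(forward(x_val), y_val).backward()   # architecture step on validation data
    A_optim.step()

    W_optim.zero_grad()
    criterion(forward(x_trn), y_trn).backward()   # weight step on training data
    W_optim.step()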
Example #21
0
    def train_epoch(self, data_loader):

        pbar = ProgressBar(n_total=len(data_loader), desc='Training')
        tr_loss = AverageMeter()
        self.tmodel.eval()
        for step, batch in enumerate(data_loader):
            self.smodel.train()

            batch = tuple(t.to(self.device) for t in batch)

            input_ids, input_mask, trigger_mask, segment_ids, label_ids, input_lens, one_hot_labels = batch
            if not self.partial:
                one_hot_labels = None
            if not self.trigger:
                trigger_mask = None
            lens = input_lens.cpu().detach().numpy().tolist()
            with torch.no_grad():
                t_features, t_loss = self.tmodel(input_ids, segment_ids,
                                                 input_mask, trigger_mask,
                                                 label_ids, lens,
                                                 one_hot_labels)
            # 	_,tag_ids = self.tmodel.crf._obtain_labels(t_features, self.id2label, input_lens)
            # teacher_tags = []
            # max_seq_length=256
            # for tag in tag_ids:
            # 	tag+=[0]*(max_seq_length-len(tag))
            # 	teacher_tags.append(tag)
            # teacher_labels = torch.tensor([tag for tag in teacher_tags], dtype=torch.long).cuda()
            # s_features = self.smodel.forward_f(input_ids, segment_ids, input_mask, trigger_mask,)
            # combine_labels = torch.where(label_ids==0,teacher_labels,label_ids)
            #
            # loss = self.smodel.crf.calculate_loss(s_features, combine_labels, input_lens)
            '''
			s_partial_loss = self.smodel.crf.partial_loss(s_features, input_lens, one_hot_labels) # outer signal
			st_loss = self.smodel.crf.calculate_loss(s_features, teacher_labels, input_lens) # sequence-level teacher signal
			kld_loss = nn.KLDivLoss(reduction='none')(F.log_softmax(s_features,-1),F.softmax(t_features,-1)) # word-level teacher signal
			bs = kld_loss.size(0)
			pad_kld_loss = kld_loss*(input_mask.float().unsqueeze(-1))
			reduced_kld_loss = torch.mean(torch.sum(pad_kld_loss.view(bs,-1),-1)/input_lens.float(),0)
			loss=(1-self.alpha)*s_partial_loss+self.alpha*(reduced_kld_loss+st_loss)
			'''
            s_features = self.smodel.forward_f(
                input_ids,
                segment_ids,
                input_mask,
                trigger_mask,
            )
            kld_loss = nn.KLDivLoss(reduction='none')(F.log_softmax(
                s_features, -1), F.softmax(t_features,
                                           -1))  # word-level teacher signal
            bs = kld_loss.size(0)
            pad_kld_loss = kld_loss * (input_mask.float().unsqueeze(-1))
            loss = torch.mean(
                torch.sum(pad_kld_loss.view(bs, -1), -1) / input_lens.float(),
                0)

            if len(self.n_gpu.split(",")) >= 2:
                loss = loss.mean()
            if self.gradient_accumulation_steps > 1:
                loss = loss / self.gradient_accumulation_steps

            loss.backward()
            clip_grad_norm_(self.smodel.parameters(), self.grad_clip)
            if (step + 1) % self.gradient_accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.global_step += 1
            tr_loss.update(loss.item(), n=1)
            # self.tb_logger.scalar_summary ("s_partial_loss", s_partial_loss.item(), self.global_step)
            # self.tb_logger.scalar_summary('st_loss',st_loss.item(),self.global_step)
            # self.tb_logger.scalar_summary('kld_loss',reduced_kld_loss.item(),self.global_step)
            pbar(step=step, info={'loss': loss.item()})
            # if step%5==0:
            # 	self.logger.info("step:{},loss={:.4f}".format(self.global_step,loss.item()))

        info = {'loss': tr_loss.avg}
        if "cuda" in str(self.device):
            torch.cuda.empty_cache()
        return info
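
The distillation loss assembled inline above (token-level KL between student and teacher, masked to real tokens and normalized by sequence length) can be factored into a standalone helper; this sketch mirrors the exact computation in Example #21:

import torch
import torch.nn as nn
import torch.nn.functional as F

def masked_kld(s_features, t_features, input_mask, input_lens):
    # elementwise KL(student || teacher), zeroed on padding, averaged per real token
    kld = nn.KLDivLoss(reduction='none')(F.log_softmax(s_features, -1),
                                         F.softmax(t_features, -1))
    kld = kld * input_mask.float().unsqueeze(-1)
    bs = kld.size(0)
    return torch.mean(torch.sum(kld.view(bs, -1), -1) / input_lens.float(), 0)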
Example #22
0
File: model.py Project: LcDog/APL
    def train(self, src_loader, tgt_loader, writer):
        num_batches = min(len(src_loader), len(tgt_loader))
        for epoch in range(self.opt.num_epoch):

            content_loss_rec = AverageMeter()
            style_loss1_rec = AverageMeter()
            style_loss2_rec = AverageMeter()
            data_time_rec = AverageMeter()
            batch_time_rec = AverageMeter()

            tem_time = time.time()
            for batch_index, batch_data in enumerate(
                    zip(src_loader, tgt_loader)):
                iteration = batch_index + 1 + epoch * num_batches

                self.G.train()
                src_batch, tgt_batch = batch_data
                src_img, _, src_name = src_batch
                tgt_img, tgt_label, tgt_name = tgt_batch
                src_img_cuda = src_img.cuda()
                tgt_img_cuda = tgt_img.cuda()
                data_time_rec.update(time.time() - tem_time)

                rec_tgt = self.G(tgt_img_cuda)  # output [-1,1]
                if (batch_index + 1) % self.opt.show_img_freq == 0:
                    rec_results = rec_tgt.detach().clone().cpu()
                rec_tgt = self.de_normalize(rec_tgt)  # return to [0,1]

                content_loss = self.compute_content_loss(rec_tgt, tgt_img_cuda)
                style_loss1, style_loss2 =\
                    self.compute_style_loss(rec_tgt, src_img_cuda)
                loss = content_loss * self.lambda_values[0] +\
                       style_loss1 * self.lambda_values[1] +\
                       style_loss2 * self.lambda_values[2]
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                content_loss_rec.update(content_loss.item())
                style_loss1_rec.update(style_loss1.item())
                style_loss2_rec.update(style_loss2.item())
                writer.add_scalar('AA_content_loss', content_loss.item(),
                                  iteration)
                writer.add_scalar('AA_style_loss_1', style_loss1.item(),
                                  iteration)
                writer.add_scalar('AA_style_loss_2', style_loss2.item(),
                                  iteration)
                batch_time_rec.update(time.time() - tem_time)
                tem_time = time.time()

                if (batch_index + 1) % self.opt.print_freq == 0:
                    print(f'Epoch [{epoch+1:d}/{self.opt.num_epoch:d}]'
                          f'[{batch_index+1:d}/{num_batches:d}]\t'
                          f'Time: {batch_time_rec.avg:.2f}   '
                          f'Data: {data_time_rec.avg:.2f}   '
                          f'Loss: {content_loss_rec.avg:.2f}   '
                          f'Style1: {style_loss1_rec.avg:.2f}   '
                          f'Style2: {style_loss2_rec.avg:.2f}')
                if (batch_index + 1) % self.opt.show_img_freq == 0:
                    fig, axes = plt.subplots(5, 1, figsize=(12, 30))
                    axes = axes.flat
                    axes[0].imshow(self.to_image(rec_results[0, ...]))
                    axes[0].set_title('rec')
                    axes[1].imshow(self.to_image(tgt_img[0, ...]))
                    axes[1].set_title(tgt_name[0])

                    rec_seg = self.compute_seg_map(rec_results).cpu().numpy()
                    tgt_img_cuda = self.de_normalize(tgt_img_cuda)
                    ori_seg = self.compute_seg_map(tgt_img_cuda).cpu().numpy()
                    tgt_label = tgt_label.numpy()
                    rec_miu = _evaluate(rec_seg, tgt_label, 19)[3]
                    ori_miu = _evaluate(ori_seg, tgt_label, 19)[3]

                    axes[2].imshow(colorize_mask(rec_seg[0, ...]))
                    axes[2].set_title(f'rec_label_{rec_miu*100:.2f}')

                    axes[3].imshow(colorize_mask(ori_seg[0, ...]))
                    axes[3].set_title(f'ori_label_{ori_miu*100:.2f}')

                    gt_label = tgt_label[0, ...]
                    axes[4].imshow(colorize_mask(gt_label))
                    axes[4].set_title('gt_label')

                    writer.add_figure('A_rec', fig, iteration)
                if (batch_index + 1) % self.opt.checkpoint_freq == 0:
                    model_name = time.strftime(
                        '%m%d_%H%M_', time.localtime(
                            time.time())) + str(iteration) + '.pth'
                    torch.save(
                        self.G.module.state_dict(),
                        os.path.join(self.opt.save_path_prefix, model_name))
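
compute_style_loss in Example #22 returns two style terms (likely from two feature layers) and is not shown. A common Gram-matrix formulation, offered only as an assumed sketch of what each term might compute:

import torch
import torch.nn.functional as F

def gram_matrix(feat):
    # feat: (N, C, H, W) -> normalized (N, C, C) Gram matrix
    n, c, h, w = feat.size()
    f = feat.view(n, c, h * w)
    return torch.bmm(f, f.transpose(1, 2)) / (c * h * w)

def style_loss(rec_feat, src_feat):
    # MSE between Gram matrices of reconstructed and source features
    return F.mse_loss(gram_matrix(rec_feat), gram_matrix(src_feat))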
Example #23
0
    def evaluate(self, valid_loader, model, criterion, epoch):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        # switch to evaluate mode
        model.eval()

        end = time.time()
        with torch.no_grad():
            for step, (input, target) in enumerate(valid_loader):
                if self.args.gpus > 0:
                    input = input.cuda(non_blocking=True)
                    target = target.cuda(non_blocking=True)

                # compute output
                output = model(input)
                loss = criterion(output, target)

                # measure accuracy and record loss
                prec1, prec5 = accuracy(output.cpu().data,
                                        target.cpu().data,
                                        topk=(1, 5))
                losses.update(loss.cpu().data.item(), input.size(0))
                top1.update(prec1.item(), input.size(0))
                top5.update(prec5.item(), input.size(0))

                # measure elapsed time
                batch_time.update(time.time() - end)
                end = time.time()

                if step % self.args.print_freq == 0:
                    print('Test: [{0}/{1}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                          'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                              step,
                              len(valid_loader),
                              batch_time=batch_time,
                              loss=losses,
                              top1=top1,
                              top5=top5))

        print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5))
        record = OrderedDict([
            ["epoch", epoch],
            ["time", batch_time.avg],
            ["loss", losses.avg],
            ["top1", top1.avg],
            ["top5", top5.avg],
        ])
        with open(self.valid_file, "a") as fp:
            fp.write(json.dumps(record) + "\n")
        return top1.avg
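
The accuracy(output, target, topk=(1, 5)) helper used by Examples #23 and #24 follows the classic ImageNet recipe; a sketch of the standard implementation (assumed, since the original is not shown):

import torch

def accuracy(output, target, topk=(1,)):
    # returns top-k precision (in percent) for each k in topk
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res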
Example #24
0
    def train(self, train_loader, model, criterion, optimizer, epoch):
        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()
        samples_per_second = AverageMeter()

        # switch to train mode
        model.train()

        end = time.time()
        for step, (input, target) in enumerate(train_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            if self.args.gpus > 0:
                input = input.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.cpu().data,
                                    target.cpu().data,
                                    topk=(1, 5))
            losses.update(loss.cpu().data.item(), input.size(0))
            top1.update(prec1.item(), input.size(0))
            top5.update(prec5.item(), input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time / executed samples
            elapsed_time = time.time() - end
            batch_time.update(elapsed_time)

            # throughput for this batch (samples processed per second)
            samples_per_sec = self.args.batch_size / elapsed_time
            samples_per_second.update(samples_per_sec)

            end = time.time()

            if step % self.args.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Speed {speed.val:.1f} (average {speed.avg:.1f} samples/s)\t'
                    'Loss {loss.val:.4f} (average {loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} (average {top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} (average {top5.avg:.3f})'.format(
                        epoch, (step + 1),
                        len(train_loader),
                        batch_time=batch_time,
                        speed=samples_per_second,
                        loss=losses,
                        top1=top1,
                        top5=top5))
                record = OrderedDict([
                    # ["iter", i / len(train_loader)],
                    ["epoch", epoch],
                    ["time", batch_time.val],
                    ["loss", losses.val],
                    ["top1", top1.val],
                    ["top5", top5.val],
                ])
                with open(self.train_file, "a") as fp:
                    fp.write(json.dumps(record) + "\n")
        self.writer.add_scalar("time", batch_time.val, epoch)
        self.writer.add_scalar("loss", losses.val, epoch)
        self.writer.add_scalar("top1", top1.val, epoch)
        self.writer.add_scalar("top5", top5.val, epoch)
Example #25
0
    class train(object):
        def __init__(self, logger_stats, model, cf, validator, stats, msg):
            # Initialize training variables
            self.logger_stats = logger_stats
            self.model = model
            self.cf = cf
            self.validator = validator
            self.logger_stats.write('\n- Starting train <--- \n')
            self.curr_epoch = 1 if self.model.best_stats.epoch == 0 else self.model.best_stats.epoch
            self.stop = False
            self.stats = stats
            self.best_acc = 0
            self.msg = msg
            self.loss = None
            self.outputs = None
            self.labels = None
            self.writer = SummaryWriter(
                os.path.join(cf.tensorboard_path, 'train'))

        def start(self,
                  train_loader,
                  train_set,
                  valid_set=None,
                  valid_loader=None):
            self.train_num_batches = math.ceil(train_set.num_images /
                                               float(self.cf.train_batch_size))
            self.val_num_batches = 0 if valid_set is None else math.ceil(
                valid_set.num_images / float(self.cf.valid_batch_size))
            # Define early stopping control
            if self.cf.early_stopping:
                early_Stopping = Early_Stopping(self.cf)
            else:
                early_Stopping = None

            prev_msg = '\nTotal estimated training time...\n'
            self.global_bar = ProgressBar(
                (self.cf.epochs + 1 - self.curr_epoch) *
                (self.train_num_batches + self.val_num_batches),
                lenBar=20)
            self.global_bar.set_prev_msg(prev_msg)

            # Train process
            for epoch in range(self.curr_epoch, self.cf.epochs + 1):
                # Shuffle train data
                train_set.update_indexes()

                # Initialize logger
                epoch_time = time.time()
                self.logger_stats.write('\t ------ Epoch: ' + str(epoch) +
                                        ' ------ \n')

                # Initialize epoch progress bar
                self.msg.accum_str = '\n\nEpoch %d/%d estimated time...\n' % \
                                     (epoch, self.cf.epochs)
                epoch_bar = ProgressBar(self.train_num_batches, lenBar=20)
                epoch_bar.update(show=False)

                # Initialize stats
                self.stats.epoch = epoch
                self.train_loss = AverageMeter()
                self.confm_list = np.zeros(
                    (self.cf.num_classes, self.cf.num_classes))

                # Train epoch
                self.training_loop(epoch, train_loader, epoch_bar)

                # Save stats
                self.stats.train.conf_m = self.confm_list
                self.compute_stats(np.asarray(self.confm_list),
                                   self.train_loss)
                self.save_stats_epoch(epoch)
                self.logger_stats.write_stat(
                    self.stats.train, epoch,
                    os.path.join(self.cf.train_json_path,
                                 'train_epoch_' + str(epoch) + '.json'))

                # Validate epoch
                self.validate_epoch(valid_set, valid_loader, early_Stopping,
                                    epoch, self.global_bar)

                # Update scheduler
                if self.model.scheduler is not None:
                    self.model.scheduler.step(self.stats.val.loss)

                # Saving model if score improvement
                new_best = self.model.save(self.stats)
                if new_best:
                    self.logger_stats.write_best_stats(self.stats, epoch,
                                                       self.cf.best_json_file)

                # Update display values
                self.update_messages(epoch, epoch_time, new_best)

                if self.stop:
                    return

            # Save model without training
            if self.cf.epochs == 0:
                self.model.save_model()

        def training_loop(self, epoch, train_loader, epoch_bar):
            # Train epoch
            for i, data in enumerate(train_loader):
                # Read Data
                inputs, labels = data

                N, c, w, h = inputs.size()  # torch image batches are NCHW
                inputs = Variable(inputs).cuda()
                self.inputs = inputs
                self.labels = Variable(labels).cuda()

                # Predict model
                self.model.optimizer.zero_grad()
                self.outputs = self.model.net(inputs)
                predictions = self.outputs.data.max(1)[1].cpu().numpy()

                # Compute gradients
                self.compute_gradients()

                # Compute batch stats
                self.train_loss.update(float(self.loss.cpu().item()), N)
                confm = compute_confusion_matrix(
                    predictions,
                    self.labels.cpu().data.numpy(), self.cf.num_classes,
                    self.cf.void_class)
                # plain array addition; map() returns a lazy iterator in Python 3
                self.confm_list = self.confm_list + confm

                # both branches of the original normalize_loss check were identical
                self.stats.train.loss = self.train_loss.avg

                if not self.cf.debug:
                    # Save stats
                    self.save_stats_batch((epoch - 1) *
                                          self.train_num_batches + i)

                    # Update epoch messages
                    self.update_epoch_messages(epoch_bar, self.global_bar,
                                               self.train_num_batches, epoch,
                                               i)

        def save_stats_epoch(self, epoch):
            # Save logger
            if epoch is not None:
                # Epoch loss tensorboard
                self.writer.add_scalar('losses/epoch', self.stats.train.loss,
                                       epoch)
                self.writer.add_scalar('metrics/accuracy',
                                       100. * self.stats.train.acc, epoch)

        def save_stats_batch(self, batch):
            # Save logger
            if batch is not None:
                self.writer.add_scalar('losses/batch', self.stats.train.loss,
                                       batch)

        def compute_gradients(self):
            self.loss = self.model.loss(self.outputs, self.labels)
            self.loss.backward()
            self.model.optimizer.step()

        def compute_stats(self, confm_list, train_loss):
            TP_list, TN_list, FP_list, FN_list = extract_stats_from_confm(
                confm_list)
            mean_accuracy = compute_accuracy(TP_list, TN_list, FP_list,
                                             FN_list)
            self.stats.train.acc = np.nanmean(mean_accuracy)
            self.stats.train.loss = float(train_loss.avg)  # avg is already a Python float, not a tensor

        def validate_epoch(self, valid_set, valid_loader, early_Stopping,
                           epoch, global_bar):

            if valid_set is not None and valid_loader is not None:
                # Set model in validation mode
                self.model.net.eval()

                self.validator.start(valid_set,
                                     valid_loader,
                                     'Epoch Validation',
                                     epoch,
                                     global_bar=global_bar)

                # Early stopping checking
                if self.cf.early_stopping:
                    early_Stopping.check(self.stats.train.loss,
                                         self.stats.val.loss,
                                         self.stats.val.mIoU,
                                         self.stats.val.acc)
                    if early_Stopping.stop:
                        self.stop = True
                # Set model in training mode
                self.model.net.train()

        def update_messages(self, epoch, epoch_time, new_best):  # new_best matches the call site above; unused here
            # Update logger
            epoch_time = time.time() - epoch_time
            self.logger_stats.write('\t Epoch step finished: %ds \n' %
                                    (epoch_time))

            # Compute best stats
            self.msg.msg_stats_last = '\nLast epoch: acc = %.2f, loss = %.5f\n' % (
                100 * self.stats.val.acc, self.stats.val.loss)
            if self.best_acc < self.stats.val.acc:
                self.msg.msg_stats_best = 'Best case: epoch = %d, acc = %.2f, loss = %.5f\n' % (
                    epoch, 100 * self.stats.val.acc, self.stats.val.loss)

                msg_confm = self.stats.val.get_confm_str()
                self.logger_stats.write(msg_confm)
                self.msg.msg_stats_best = self.msg.msg_stats_best + msg_confm

                self.best_acc = self.stats.val.acc

        def update_epoch_messages(self, epoch_bar, global_bar,
                                  train_num_batches, epoch, batch):
            # Update progress bar
            epoch_bar.set_msg('loss = %.5f' % self.stats.train.loss)
            self.msg.last_str = epoch_bar.get_message(step=True)
            global_bar.set_msg(self.msg.accum_str + self.msg.last_str + self.msg.msg_stats_last + \
                               self.msg.msg_stats_best)
            global_bar.update()

            # writer.add_scalar('train_loss', train_loss.avg, curr_iter)

            # Display progress
            curr_iter = (epoch - 1) * train_num_batches + batch + 1
            if (batch + 1) % math.ceil(train_num_batches / 20.) == 0:
                self.logger_stats.write(
                    '[Global iteration %d], [iter %d / %d], [train loss %.5f] \n'
                    % (curr_iter, batch + 1, train_num_batches,
                       self.stats.train.loss))
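
Examples #17, #25, and #27 all funnel their confusion matrices through extract_stats_from_confm and compute_accuracy, neither of which is shown. Assumed sketches consistent with those call sites (per-class TP/TN/FP/FN from a num_classes x num_classes matrix with rows as ground truth, then a per-class recall-style accuracy that the callers reduce with np.nanmean):

import numpy as np

def extract_stats_from_confm(confm):
    TP = np.diag(confm)
    FP = confm.sum(axis=0) - TP
    FN = confm.sum(axis=1) - TP
    TN = confm.sum() - TP - FP - FN
    return TP, TN, FP, FN

def compute_accuracy(TP, TN, FP, FN):
    # per-class accuracy; 0/0 yields nan, which np.nanmean skips
    with np.errstate(divide='ignore', invalid='ignore'):
        return TP / (TP + FN)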
Example #26
0
def inference(data_path, threshold=0.3):

    name_index, index_name = read_class_names(args.classes)
    output_col = ['image_name'] + list(name_index.keys())
    submission_col = ['image_name', 'tags']
    inference_dir = None

    #----------------------------dataset generator--------------------------------
    test_transform = get_transform(size=args.image_size, mode='test')
    test_dataset = PlanetDataset(image_root=data_path,
                                 phase='test',
                                 img_type=args.image_type,
                                 img_size=args.image_size,
                                 transform=test_transform)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=args.num_works)
    # ---------------------------------load model and param--------------------------------
    model = build_model(model_name=args.model_name,
                        num_classes=args.num_classes,
                        global_pool=args.global_pool)

    if args.best_checkpoint is not None:
        assert os.path.isfile(args.best_checkpoint), '{} not found'.format(
            args.best_checkpoint)
        checkpoint = torch.load(args.best_checkpoint)
        print('Restoring model with {} architecture...'.format(
            checkpoint['arch']))

        # load model weights
        if use_cuda:
            if checkpoint['num_gpu'] > 1:
                model = torch.nn.DataParallel(model, device_ids=gpu_ids).cuda()
            else:
                model.cuda()
        else:
            if checkpoint['num_gpu'] > 1:
                model = torch.nn.DataParallel(model)
            # CPU mode: nothing to move to GPU
        model.load_state_dict(checkpoint['state_dict'])

        # update threshold (keep it a tensor so .cuda()/expand_as below work)
        if 'threshold' in checkpoint:
            threshold = checkpoint['threshold']
            print('Using thresholds:', threshold)
        threshold = torch.tensor(threshold, dtype=torch.float32)

        if use_cuda:
            threshold = threshold.cuda()

        # generate save path
        inference_dir = os.path.join(
            os.path.normcase(args.inference_path),
            '{}-f{}-{:.6f}'.format(checkpoint['arch'], checkpoint['fold'],
                                   checkpoint['f2']))
        os.makedirs(inference_dir, exist_ok=True)

        print('Model restored from file: {}'.format(args.best_checkpoint))
    else:
        assert False and "No checkpoint specified"

    # -------------------------------------inference---------------------------------------
    model.eval()

    batch_time_meter = AverageMeter()
    results_raw = []
    results_label = []
    results_submission = []

    since_time = time.time()
    pbar = tqdm(enumerate(test_loader))
    try:
        with torch.no_grad():
            start = time.time()
            for batch_idx, (inputs, _, indices) in pbar:
                if use_cuda:
                    inputs = inputs.cuda()
                # input_var = autograd.Variable(input, volatile=True)
                # input_var = torch.autograd.Variable(inputs)
                input_var = inputs
                outputs = model(input_var)

                if args.multi_label:
                    if args.loss == 'nll':
                        outputs = F.softmax(outputs, dim=1)
                    else:
                        outputs = torch.sigmoid(outputs)

                expand_threshold = torch.unsqueeze(threshold,
                                                   0).expand_as(outputs)
                output_labels = (outputs.data > expand_threshold).byte()

                # move data to CPU and collect
                outputs = outputs.cpu().data.numpy()
                output_labels = output_labels.cpu().numpy()
                indices = indices.cpu().numpy().flatten()

                for index, output, output_label in zip(indices, outputs,
                                                       output_labels):

                    image_name = os.path.splitext(
                        os.path.basename(test_dataset.images[index]))[0]
                    results_raw.append([image_name] + list(output))
                    results_label.append([image_name] + list(output_label))
                    results_submission.append(
                        [image_name] +
                        [index_to_tag(output_label, index_name)])

                batch_time_meter.update(time.time() - start)
                if batch_idx % args.summary_iter == 0:
                    print(
                        'Inference: [{}/{} ({:.0f}%)]  '
                        'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
                        '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '.format(
                            batch_idx * len(inputs),
                            len(test_loader.sampler),
                            100. * batch_idx / len(test_loader),
                            batch_time=batch_time_meter,
                            rate=input_var.size(0) / batch_time_meter.val,
                            rate_avg=input_var.size(0) / batch_time_meter.avg))

                start = time.time()

    except KeyboardInterrupt:
        pass

    results_raw_df = pd.DataFrame(results_raw, columns=output_col)
    results_raw_df.to_csv(os.path.join(inference_dir, 'results_raw.csv'),
                          index=False)
    results_label_df = pd.DataFrame(results_label, columns=output_col)
    results_label_df.to_csv(os.path.join(inference_dir, 'results_thr.csv'),
                            index=False)
    results_submission_df = pd.DataFrame(results_submission,
                                         columns=submission_col)
    results_submission_df.to_csv(os.path.join(inference_dir, 'submission.csv'),
                                 index=False)

    time_elapsed = time.time() - since_time
    print('*** Inference complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
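
index_to_tag, used when building the submission rows above, is not shown; a plausible one-liner given the Planet multi-label format (space-separated tag names for every positive class):

def index_to_tag(output_label, index_name):
    # index_name: dict mapping class index -> tag string (an assumed shape)
    return ' '.join(index_name[i] for i, flag in enumerate(output_label) if flag)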
Example #27
0
    class validation:
        def __init__(self, logger_stats, model, cf, stats, msg):
            # Initialize validation variables
            self.logger_stats = logger_stats
            self.model = model
            self.cf = cf
            self.stats = stats
            self.msg = msg
            self.writer = SummaryWriter(os.path.join(cf.tensorboard_path, 'validation'))

        def start(self, valid_set, valid_loader, mode='Validation', epoch=None, save_folder=None):
            confm_list = np.zeros((self.cf.num_classes, self.cf.num_classes))

            self.val_loss = AverageMeter()

            # Validate model
            if self.cf.problem_type == 'detection':
                self.validation_loop(epoch, valid_loader, valid_set, save_folder)
            else:
                self.validation_loop(epoch, valid_loader, valid_set, confm_list)

            # Compute stats
            self.compute_stats(np.asarray(self.stats.val.conf_m), self.val_loss)

            # Save stats
            self.save_stats(epoch, mode)
            if mode == 'Epoch Validation':
                self.logger_stats.write_stat(self.stats.train, epoch,
                                             os.path.join(self.cf.train_json_path,
                                                          'valid_epoch_' + str(epoch) + '.json'))
            elif mode == 'Validation':
                self.logger_stats.write_stat(self.stats.val, epoch, self.cf.val_json_file)
            elif mode == 'Test':
                self.logger_stats.write_stat(self.stats.val, epoch, self.cf.test_json_file)

        def validation_loop(self, epoch, valid_loader, valid_set, confm_list):
            for vi, data in tqdm(enumerate(valid_loader), desc="Validation", total=len(valid_loader),
                                 file=sys.stdout):
                # Read data
                inputs, gts = data
                n_images, w, h, c = inputs.size()
                inputs = Variable(inputs).cuda()
                gts = Variable(gts).cuda()

                # Predict model
                with torch.no_grad():
                    outputs = self.model.net(inputs)
                    predictions = outputs.data.max(1)[1].cpu().numpy()

                    # Compute batch stats
                    self.val_loss.update(float(self.model.loss(outputs, gts).cpu().item() / n_images), n_images)
                    confm = compute_confusion_matrix(predictions, gts.cpu().data.numpy(), self.cf.num_classes,
                                                     self.cf.void_class)
                    confm_list = confm_list + confm

                # Save epoch stats
                self.stats.val.conf_m = confm_list
                # both branches of the original normalize_loss check were identical
                self.stats.val.loss = self.val_loss.avg

                # Save predictions and generate overlapping visualizations
                self.update_tensorboard(inputs.cpu(), gts.cpu(),
                                        predictions, epoch, range(vi * self.cf.valid_batch_size,
                                                                  vi * self.cf.valid_batch_size +
                                                                  np.shape(predictions)[0]),
                                        valid_set.num_images)

        def update_tensorboard(self, inputs, gts, predictions, epoch, indexes, val_len):
            pass

        def compute_stats(self, confm_list, val_loss):
            TP_list, TN_list, FP_list, FN_list = extract_stats_from_confm(confm_list)
            mean_accuracy = compute_accuracy(TP_list, TN_list, FP_list, FN_list)
            self.stats.val.acc = np.nanmean(mean_accuracy)
            self.stats.val.loss = val_loss.avg

        def save_stats(self, epoch, mode):
            # Save logger
            if epoch is not None:
                self.logger_stats.write('----------------- Epoch scores summary ------------------------- \n')
                self.logger_stats.write('[epoch %d], [val loss %.5f], [acc %.2f] \n' % (
                    epoch, self.stats.val.loss, 100 * self.stats.val.acc))
                self.logger_stats.write('---------------------------------------------------------------- \n')
            else:
                self.logger_stats.write('----------------- Scores summary -------------------- \n')
                self.logger_stats.write('[%s loss %.5f], [acc %.2f] \n' %
                                        (mode, self.stats.val.loss, 100 * self.stats.val.acc))
                self.logger_stats.write('---------------------------------------------------------------- \n')
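
compute_confusion_matrix appears in Examples #17, #25, and #27 but is never defined. A sketch under the assumption that rows index ground truth, columns index predictions, and void_class pixels are excluded:

import numpy as np

def compute_confusion_matrix(predictions, gts, num_classes, void_class=None):
    predictions = np.asarray(predictions).flatten()
    gts = np.asarray(gts).flatten()
    if void_class is not None:
        keep = gts != void_class
        predictions, gts = predictions[keep], gts[keep]
    confm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confm, (gts, predictions), 1)   # count (gt, pred) pairs
    return confm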
Example #28
0
def evaluate_single_epoch(config, model, dataloader, criterion, log_val, epoch,
                          writer, dataset_size):
    batch_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    eval_loss = tf.keras.metrics.Mean(name='eval_loss')
    eval_accuracy = tf.keras.metrics.BinaryAccuracy(name='eval_accuracy')
    end = time.time()
    for i, (images, labels) in enumerate(dataloader):

        preds = model(images)

        loss = criterion(labels, preds)
        eval_loss(loss)
        loss_mean = eval_loss.result().numpy()

        losses.update(loss_mean, 1)
        eval_accuracy(labels, preds)
        score = eval_accuracy.result().numpy()
        scores.update(score, 1)

        batch_time.update(time.time() - end)
        end = time.time()

        if i % config.PRINT_EVERY == 0:
            print('[%2d/%2d] time: %.2f, loss: %.6f, score: %.4f' %
                  (i, dataset_size, batch_time.sum, loss_mean, score))

        if i < 5:  # only log annotated images for the first few batches
            iteration = dataset_size * epoch + i
            annotated_images = utils.tools.annotate_to_images(
                images, labels, preds.numpy())
            for idx, annotated_image in enumerate(annotated_images):
                writer.add_image(
                    'val/{}_image_{}_class_{}'.format(i, int(idx / 8),
                                                      idx % 8),
                    annotated_image, iteration)

        del images, labels, preds
        # stop once one epoch's worth of batches has been consumed
        if i > dataset_size / config.EVAL.BATCH_SIZE:
            break
    writer.add_scalar('val/loss', losses.avg, epoch)
    writer.add_scalar('val/score', scores.avg, epoch)
    log_val.write('[%d/%d] loss: %.6f, score: %.4f\n' %
                  (epoch, config.TRAIN.NUM_EPOCHS, losses.avg, scores.avg))
    print('average loss over VAL epoch: %f' % losses.avg)

    return scores.avg, losses.avg
Example #29
0
def train_single_epoch(config, model, dataloader, criterion, optimizer,
                       log_train, epoch, writer, dataset_size):
    batch_time = AverageMeter()
    losses = AverageMeter()
    scores = AverageMeter()
    # if epoch > 3 and epoch <6:
    #     optimizer.learning_rate = 0.001
    # elif epoch > 6:
    #     optimizer.learning_rate = 0.001

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.BinaryAccuracy(name='train_accuracy')
    smoother = LabelSmoother()
    end = time.time()
    for i, (images, labels) in enumerate(dataloader):
        with tf.GradientTape() as grad_tape:
            # (N, H, W, C)
            # (N, 512, 288, 1)
            preds = model(images)
            # expand_labels = tf.expand_dims(labels, -1)
            # cv2.imshow('images',np.uint8(images[0]*255))
            # cv2.imshow('labels',np.uint8(labels[0]*255))
            # cv2.waitKey()

            if config.LOSS.LABEL_SMOOTHING:
                loss = criterion(smoother(labels), preds)
            else:
                loss = criterion(labels, preds)

        gradients = grad_tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # preds = tf.sigmoid(logits)
        train_loss(loss)
        train_accuracy(labels, preds)
        # print(epoch, "loss: ", train_loss.result().numpy(), "acc: ", train_accuracy.result().numpy(), "step: ", i)
        # NOTE: meters are updated with n=1; the tf.keras metrics above already average internally
        losses.update(train_loss.result().numpy(), 1)
        scores.update(train_accuracy.result().numpy(), 1)

        batch_time.update(time.time() - end)
        end = time.time()
        dataloader_len = dataset_size / config.TRAIN.BATCH_SIZE
        if i % config.PRINT_EVERY == 0:
            print(
                "[%d/%d][%d/%d] time: %.2f, loss: %.6f, score: %.4f, lr: %f" %
                (epoch, config.TRAIN.NUM_EPOCHS, i, dataloader_len,
                 batch_time.sum, train_loss.result().numpy(),
                 train_accuracy.result().numpy(),
                 optimizer.learning_rate.numpy()))

        if i == 0:
            iteration = dataloader_len * epoch + i
            annotated_images = utils.tools.annotate_to_images(
                images, labels, preds.numpy())
            for idx, annotated_image in enumerate(annotated_images):
                writer.add_image(
                    'train/image_{}_class_{}'.format(int(idx / 8), idx % 8),
                    annotated_image, iteration)

        del images, labels, preds
        # stop once one epoch's worth of batches has been consumed
        if i > dataset_size / config.TRAIN.BATCH_SIZE:
            break
    writer.add_scalar('train/score', scores.avg, epoch)
    writer.add_scalar('train/loss', losses.avg, epoch)
    writer.add_scalar('train/lr', optimizer.learning_rate.numpy(), epoch)
    log_train.write('[%d/%d] loss: %.6f, score: %.4f, lr: %f\n' %
                    (epoch, config.TRAIN.NUM_EPOCHS, losses.avg, scores.avg,
                     optimizer.learning_rate.numpy()))
    print('average loss over TRAIN epoch: %f' % losses.avg)
Example #30
0
def cal_acc(data_list, pred_folder, classes):
    intersection_meter = AverageMeter()
    union_meter = AverageMeter()
    target_meter = AverageMeter()

    for i, (image_path, target_path) in enumerate(data_list):
        image_name = image_path.split('/')[-1].split('.')[0]
        pred = cv2.imread(os.path.join(pred_folder, image_name+'.png'), cv2.IMREAD_GRAYSCALE)
        target = cv2.imread(target_path, cv2.IMREAD_GRAYSCALE)
        intersection, union, target = intersection_union(pred, target, classes)
        intersection_meter.update(intersection)
        union_meter.update(union)
        target_meter.update(target)
        # accuracy = sum(intersection_meter.val) / (sum(target_meter.val) + 1e-10)
        # print('Evaluating {0}/{1} on image {2}, accuracy {3:.4f}.'.format(i + 1, len(data_list), image_name+'.png', accuracy))

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    accuracy_class = intersection_meter.sum / (target_meter.sum + 1e-10)
    mIoU = np.mean(iou_class)
    mAcc = np.mean(accuracy_class)
    allAcc = sum(intersection_meter.sum) / (sum(target_meter.sum) + 1e-10)

    print('Eval result: mIoU/mAcc/allAcc {:.4f}/{:.4f}/{:.4f}.'.format(mIoU, mAcc, allAcc))
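
intersection_union in Example #30 feeds per-class histograms into the three meters; a sketch of the widely used semantic-segmentation helper it resembles (assumed, with an ignore label of 255):

import numpy as np

def intersection_union(pred, target, classes, ignore_index=255):
    pred, target = pred.flatten(), target.flatten()
    valid = target != ignore_index
    pred, target = pred[valid], target[valid]
    intersection = pred[pred == target]           # labels where pred and target agree
    area_intersection, _ = np.histogram(intersection, bins=np.arange(classes + 1))
    area_pred, _ = np.histogram(pred, bins=np.arange(classes + 1))
    area_target, _ = np.histogram(target, bins=np.arange(classes + 1))
    area_union = area_pred + area_target - area_intersection
    return area_intersection, area_union, area_target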