Example 1
def train(trainloader, model, lemniscate, criterion, optimizer, epoch,
          train_writer):

    losses = MetricTracker()

    model.train()

    for idx, data in enumerate(tqdm(trainloader, desc="training")):

        imgs = data['img'].to(torch.device("cuda"))
        index = data["idx"].to(torch.device("cuda"))

        feature = model(imgs)
        output = lemniscate(feature, index)
        loss = criterion(output, index)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        losses.update(loss.item(), imgs.size(0))

    train_writer.add_scalar("loss", losses.avg, epoch)

    print(f'Train loss: {losses.avg:.6f}')
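
Note: every loop in this listing relies on a MetricTracker helper that the examples themselves do not define. From the way it is used here (update(value, n) and .avg), a minimal sketch could look like the code below; this is an assumption for readability, not the original class (Examples 6 and 7 use a richer, name-keyed variant with a TensorBoard writer).

class MetricTracker:
    """Keeps a running average of a scalar metric (assumed interface)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # last value seen
        self.sum = 0.0   # weighted sum of values
        self.count = 0   # total weight (e.g. number of samples)
        self.avg = 0.0   # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count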
Example 2
def train(trainloader, model, lemniscate, criterion, optimizer, epoch,
          train_writer):

    losses = MetricTracker()

    model.train()

    # Every 10th epoch after args.start_prune, re-estimate the criterion's
    # per-sample weights with the current model before the normal training pass.
    if (epoch + 1) >= args.start_prune and (epoch + 1) % 10 == 0:

        # checkpoint = torch.load(os.path.join(checkpoint_dir, sv_name + '_model_best.pth.tar'))
        # model.load_state_dict(checkpoint['model_state_dict'])
        # lemniscate = checkpoint['lemniscate']

        model.eval()

        for idx, data in enumerate(tqdm(trainloader, desc="training")):

            imgs = data['img'].to(torch.device("cuda"))
            index = data["idx"].to(torch.device("cuda"))

            feature = model(imgs)
            output = lemniscate(feature, index)

            criterion.update_weight(output, index)

        # checkpoint = torch.load(os.path.join(checkpoint_dir, sv_name + '_checkpoint.pth.tar'))
        # model.load_state_dict(checkpoint['model_state_dict'])

        model.train()

    for idx, data in enumerate(tqdm(trainloader, desc="training")):

        imgs = data['img'].to(torch.device("cuda"))
        index = data["idx"].to(torch.device("cuda"))

        feature = model(imgs)
        output = lemniscate(feature, index)
        loss = criterion(output, index)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        losses.update(loss.item(), imgs.size(0))

    train_writer.add_scalar("loss", losses.avg, epoch)
    print(f'Train loss: {losses.avg:.6f}')
Example 3
def trainMoCo(epoch, trainloader, model, model_ema, contrast, criterion,
              optimizer, train_writer):

    loss_meter = MetricTracker()

    model.train()
    model_ema.eval()

    def set_bn_train(m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.train()

    model_ema.apply(set_bn_train)

    for idx, data in enumerate(tqdm(trainloader, desc="training")):

        imgs = data['anchor'].to(torch.device("cuda"))
        aug_imgs = data['neighbor'].to(torch.device("cuda"))
        index = data['idx'].to(torch.device("cuda"))

        bsz = imgs.size(0)

        # Shuffled BN: permute the key batch before the momentum encoder and
        # restore the original order of the resulting key features.
        shuffle_ids, reverse_ids = get_shuffle_ids(bsz)

        feat_q = model(imgs)
        with torch.no_grad():
            aug_imgs = aug_imgs[shuffle_ids]
            feat_k = model_ema(aug_imgs)
            feat_k = feat_k[reverse_ids]

        out = contrast(feat_q, feat_k)

        loss = criterion(out)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_meter.update(loss.item(), bsz)
        moment_update(model, model_ema, args.alpha)

    info = {
        "Loss": loss_meter.avg,
    }
    for tag, value in info.items():
        train_writer.add_scalar(tag, value, epoch)

    print('Train Loss: {:.6f} '.format(loss_meter.avg))
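
get_shuffle_ids and moment_update are not defined in this example. Judging from their call sites, they implement the shuffled-BN batch permutation and the momentum (EMA) update of the key encoder; a plausible sketch, assumed rather than taken from the original code, is:

import torch

def get_shuffle_ids(bsz):
    """Random permutation of the batch and its inverse, for shuffled BN."""
    shuffle_ids = torch.randperm(bsz).cuda()
    reverse_ids = torch.argsort(shuffle_ids)
    return shuffle_ids, reverse_ids

def moment_update(model, model_ema, m):
    """EMA update of the key encoder: p_ema = m * p_ema + (1 - m) * p."""
    with torch.no_grad():
        for p, p_ema in zip(model.parameters(), model_ema.parameters()):
            p_ema.data.mul_(m).add_(p.data, alpha=1.0 - m)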
Example 4
def train(train_loader, model, criterion, optimizer, train_writer, epoch):

    # train_acc = MetricTracker()
    train_loss = MetricTracker()
    # train_IoU = MetricTracker()
    # train_BCE = metrics.MetricTracker()
    # train_DICE = metrics.MetricTracker()

    meters = {"train_loss": train_loss}

    model.train()

    for idx, data in enumerate(
            tqdm(train_loader, desc="training", ascii=True, ncols=20)):

        meters = make_train_step(idx, data, model, optimizer, criterion,
                                 meters)

    info = {
        "Loss": meters["train_loss"].avg,
        # "Acc": meters["train_acc"].avg,
        # "IoU": meters["train_IoU"].avg
    }

    for tag, value in info.items():
        train_writer.add_scalar(tag, value, epoch)

    print('Train Loss: {:.6f}'.format(
        meters["train_loss"].avg,
        # meters["train_acc"].avg,
        # meters["train_IoU"].avg
    ))

    return meters
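
make_train_step is not shown in this example. From its call site (batch index, batch, model, optimizer, criterion and the meters dict go in, the updated meters come out) it presumably performs one optimization step; a hypothetical sketch, assuming the same 'sat_img'/'map_img' batch keys as the validation loop in Example 5, might be:

def make_train_step(idx, data, model, optimizer, criterion, meters):
    """One training step: forward, loss, backward, optimizer update, meter update."""
    inputs = data['sat_img'].cuda()
    labels = data['map_img'].cuda()

    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    meters["train_loss"].update(loss.item(), inputs.size(0))
    return meters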
Example 5
def val(val_dataloader, model, val_writer, epoch):

    val_Dice = MetricTracker()

    model.eval()

    with torch.no_grad():
        for idx, data in enumerate(tqdm(val_dataloader, desc="val", ncols=10)):

            inputs = data['sat_img'].cuda()
            labels = data['map_img'].cuda()

            outputs = model(inputs)

            outputs = torch.argmax(outputs, dim=1).float()

            val_Dice.update(dice_coeff(outputs, labels), outputs.size(0))

    info = {"Dice": val_Dice.avg}

    for tag, value in info.items():
        val_writer.add_scalar(tag, value, epoch)

    print('Val Dice: {:.6f}'.format(val_Dice.avg))
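
dice_coeff is assumed to return a scalar Dice score for a batch of predicted masks; a minimal sketch consistent with how it is used above (float prediction and label tensors in, a plain number out) could be:

def dice_coeff(pred, target, eps=1e-7):
    """Mean Dice coefficient over a batch of binary masks."""
    pred = pred.contiguous().view(pred.size(0), -1)
    target = target.contiguous().view(target.size(0), -1).float()
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + eps) / (pred.sum(dim=1) + target.sum(dim=1) + eps)
    return dice.mean().item()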
Example 6
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model instance to train (moved to the configured device).
        :param optimizer: optimizer for the model parameters.
        :param config: parsed configuration object (trainer settings, logger, save/log dirs).
        :param data_loader: training data loader.
        :param valid_data_loader: optional validation data loader; validation runs only if given.
        :param lr_scheduler: optional learning-rate scheduler, stepped once per epoch.
        :param max_len_step: controls the number of batches (steps) in each epoch;
                             if None, one epoch is a full pass over data_loader.
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # the checkpoint is loaded before multi-GPU wrapping, so its state_dict
        # keys do not carry the 'module.' prefix
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        self.data_loader = data_loader
        if max_len_step is None:  # max length of iteration step of every epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int(
            np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)
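
inf_loop, used above for iteration-based training, is not part of the example; a common implementation simply re-yields the data loader indefinitely, for instance (assumed):

from itertools import repeat

def inf_loop(data_loader):
    """Wrap a DataLoader so iteration never stops at an epoch boundary."""
    for loader in repeat(data_loader):
        yield from loader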
Example 7
class Trainer:
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 max_len_step=None):
        '''
        :param model: model instance to train (moved to the configured device).
        :param optimizer: optimizer for the model parameters.
        :param config: parsed configuration object (trainer settings, logger, save/log dirs).
        :param data_loader: training data loader.
        :param valid_data_loader: optional validation data loader; validation runs only if given.
        :param lr_scheduler: optional learning-rate scheduler, stepped once per epoch.
        :param max_len_step: controls the number of batches (steps) in each epoch;
                             if None, one epoch is a full pass over data_loader.
        '''
        self.config = config
        self.distributed = config['distributed']
        if self.distributed:
            self.local_master = (config['local_rank'] == 0)
            self.global_master = (dist.get_rank() == 0)
        else:
            self.local_master = True
            self.global_master = True
        self.logger = config.get_logger(
            'trainer',
            config['trainer']['log_verbosity']) if self.local_master else None

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['local_rank'], config['local_world_size'])
        self.model = model.to(self.device)

        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        monitor_open = cfg_trainer['monitor_open']
        if monitor_open:
            self.monitor = cfg_trainer.get('monitor', 'off')
        else:
            self.monitor = 'off'

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.monitor_mode = 'off'
            self.monitor_best = 0
        else:
            self.monitor_mode, self.monitor_metric = self.monitor.split()
            assert self.monitor_mode in ['min', 'max']

            self.monitor_best = inf if self.monitor_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)
            self.early_stop = inf if self.early_stop == -1 else self.early_stop

        self.start_epoch = 1

        if self.local_master:
            self.checkpoint_dir = config.save_dir
            # setup visualization writer instance
            self.writer = TensorboardWriter(config.log_dir, self.logger,
                                            cfg_trainer['tensorboard'])

        # load checkpoint for resume training
        if config.resume is not None:
            self._resume_checkpoint(config.resume)

        # the checkpoint is loaded before multi-GPU wrapping, so its state_dict
        # keys do not carry the 'module.' prefix
        if self.config['trainer']['sync_batch_norm'] and self.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.distributed:
            self.model = DDP(self.model,
                             device_ids=self.device_ids,
                             output_device=self.device_ids[0],
                             find_unused_parameters=True)

        self.data_loader = data_loader
        if max_len_step is None:  # max length of iteration step of every epoch
            # epoch-based training
            self.len_step = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_step = max_len_step
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        log_step = self.config['trainer']['log_step_interval']
        self.log_step = log_step if log_step != -1 and 0 < log_step < self.len_step else int(
            np.sqrt(data_loader.batch_size))

        val_step_interval = self.config['trainer']['val_step_interval']
        # self.val_step_interval = val_step_interval if val_step_interval!= -1 and 0 < val_step_interval < self.len_step\
        #                                             else int(np.sqrt(data_loader.batch_size))
        self.val_step_interval = val_step_interval

        self.gl_loss_lambda = self.config['trainer']['gl_loss_lambda']

        self.train_loss_metrics = MetricTracker(
            'loss',
            'gl_loss',
            'crf_loss',
            writer=self.writer if self.local_master else None)
        self.valid_f1_metrics = SpanBasedF1MetricTracker(iob_labels_vocab_cls)

    def train(self):
        """
        Full training logic, including train and validation.
        """

        if self.distributed:
            dist.barrier()  # Syncing machines before training

        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):

            # ensure distribute worker sample different data,
            # set different random seed by passing epoch to sampler
            if self.distributed:
                self.data_loader.sampler.set_epoch(epoch)
            result_dict = self._train_epoch(epoch)

            # print logged informations to the screen
            if self.do_validation:
                val_result_dict = result_dict['val_result_dict']
                val_res = SpanBasedF1MetricTracker.dict2str(val_result_dict)
            else:
                val_res = ''
            # every epoch log information
            self.logger_info(
                '[Epoch Validation] Epoch:[{}/{}] Total Loss: {:.6f} '
                'GL_Loss: {:.6f} CRF_Loss: {:.6f} \n{}'.format(
                    epoch, self.epochs, result_dict['loss'],
                    result_dict['gl_loss'] * self.gl_loss_lambda,
                    result_dict['crf_loss'], val_res))

            # evaluate model performance according to configured metric, check early stop, and
            # save best checkpoint as model_best
            best = False
            if self.monitor_mode != 'off' and self.do_validation:
                best, not_improved_count = self._is_best_monitor_metric(
                    best, not_improved_count, val_result_dict)
                if not_improved_count > self.early_stop:
                    self.logger_info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

    def _is_best_monitor_metric(self, best, not_improved_count,
                                val_result_dict):
        '''
        monitor metric
        :param best:
        :param not_improved_count:
        :param val_result_dict:
        :return:
        '''
        entity_name, metric = self.monitor_metric.split('-')
        try:
            # check whether model performance improved or not, according to the specified metric (monitor_metric)
            val_monitor_metric_res = val_result_dict[entity_name][metric]
            improved = (self.monitor_mode == 'min' and val_monitor_metric_res <= self.monitor_best) or \
                       (self.monitor_mode == 'max' and val_monitor_metric_res >= self.monitor_best)
        except KeyError:
            self.logger_warning(
                "Warning: Metric '{}' is not found. "
                "Model performance monitoring is disabled.".format(
                    self.monitor_metric))
            self.monitor_mode = 'off'
            improved = False
        if improved:
            self.monitor_best = val_monitor_metric_res
            not_improved_count = 0
            best = True
        else:
            not_improved_count += 1
        return best, not_improved_count

    def _train_epoch(self, epoch):
        '''
        Training logic for an epoch
        :param epoch: Integer, current training epoch.
        :return: A log dict that contains average loss and metric in this epoch.
        '''
        self.model.train()
        self.train_loss_metrics.reset()
        ## step iteration start ##
        for step_idx, input_data_item in enumerate(self.data_loader):
            step_idx += 1
            for key, input_value in input_data_item.items():
                if input_value is not None and isinstance(
                        input_value, torch.Tensor):
                    input_data_item[key] = input_value.to(self.device,
                                                          non_blocking=True)
            if self.config['trainer']['anomaly_detection']:
                # This mode will increase the runtime and should only be enabled for debugging
                with torch.autograd.detect_anomaly():
                    self.optimizer.zero_grad()
                    # model forward
                    output = self.model(**input_data_item)
                    # calculate loss
                    gl_loss = output['gl_loss']
                    crf_loss = output['crf_loss']
                    total_loss = torch.sum(
                        crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                    # backward
                    total_loss.backward()
                    # self.average_gradients(self.model)
                    self.optimizer.step()
            else:
                self.optimizer.zero_grad()
                # model forward
                output = self.model(**input_data_item)
                # calculate loss
                gl_loss = output['gl_loss']
                crf_loss = output['crf_loss']
                total_loss = torch.sum(
                    crf_loss) + self.gl_loss_lambda * torch.sum(gl_loss)
                # backward
                total_loss.backward()
                # self.average_gradients(self.model)
                self.optimizer.step()

            # Use a barrier() to make sure that all processes have finished forward and backward
            if self.distributed:
                dist.barrier()
                # obtain the sum of total_loss across all processes
                dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)

                size = dist.get_world_size()
            else:
                size = 1
            gl_loss /= size  # scale the local gl_loss by the world size
            crf_loss /= size  # scale the local crf_loss by the world size

            # calculate average loss across the batch size
            avg_gl_loss = torch.mean(gl_loss)
            avg_crf_loss = torch.mean(crf_loss)
            avg_loss = avg_crf_loss + self.gl_loss_lambda * avg_gl_loss
            # update metrics
            self.writer.set_step((epoch - 1) * self.len_step + step_idx -
                                 1) if self.local_master else None
            self.train_loss_metrics.update('loss', avg_loss.item())
            self.train_loss_metrics.update(
                'gl_loss',
                avg_gl_loss.item() * self.gl_loss_lambda)
            self.train_loss_metrics.update('crf_loss', avg_crf_loss.item())

            # log messages
            if step_idx % self.log_step == 0:
                self.logger_info(
                    'Train Epoch:[{}/{}] Step:[{}/{}] Total Loss: {:.6f} GL_Loss: {:.6f} CRF_Loss: {:.6f}'
                    .format(epoch, self.epochs, step_idx, self.len_step,
                            avg_loss.item(),
                            avg_gl_loss.item() * self.gl_loss_lambda,
                            avg_crf_loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            # do validation after val_step_interval iteration
            if self.do_validation and step_idx % self.val_step_interval == 0:
                val_result_dict = self._valid_epoch(epoch)
                self.logger_info(
                    '[Step Validation] Epoch:[{}/{}] Step:[{}/{}]  \n{}'.
                    format(epoch, self.epochs, step_idx, self.len_step,
                           SpanBasedF1MetricTracker.dict2str(val_result_dict)))

                # check if best metric, if true, then save as model_best checkpoint.
                best, not_improved_count = self._is_best_monitor_metric(
                    False, 0, val_result_dict)
                if best:
                    self._save_checkpoint(epoch, best)

            # stop after exactly len_step batches (relevant for iteration-based training)
            if step_idx == self.len_step:
                break

        ## step iteration end ##

        # {'loss': avg_loss, 'gl_loss': avg_gl_loss, 'crf_loss': avg_crf_loss}
        log = self.train_loss_metrics.result()

        # do validation after training an epoch
        if self.do_validation:
            val_result_dict = self._valid_epoch(epoch)
            log['val_result_dict'] = val_result_dict

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        '''
         Validate after training an epoch or regular step, this is a time-consuming procedure if validation data is big.
        :param epoch: Integer, current training epoch.
        :return: A dict that contains information about validation
        '''

        self.model.eval()
        self.valid_f1_metrics.reset()
        with torch.no_grad():
            for step_idx, input_data_item in enumerate(self.valid_data_loader):
                for key, input_value in input_data_item.items():
                    if input_value is not None and isinstance(
                            input_value, torch.Tensor):
                        input_data_item[key] = input_value.to(
                            self.device, non_blocking=True)

                output = self.model(**input_data_item)
                logits = output['logits']
                new_mask = output['new_mask']
                if hasattr(self.model, 'module'):
                    #  List[(List[int], torch.Tensor)] contain the tag indices of the maximum likelihood tag sequence.
                    #  and the score of the viterbi path.
                    best_paths = self.model.module.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                else:
                    best_paths = self.model.decoder.crf_layer.viterbi_tags(
                        logits, mask=new_mask, logits_batch_first=True)
                predicted_tags = []
                for path, score in best_paths:
                    predicted_tags.append(path)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + step_idx, 'valid') \
                    if self.local_master else None

                # calculate and update f1 metrics
                # (B, N*T, out_dim)
                predicted_tags_hard_prob = logits * 0
                for i, instance_tags in enumerate(predicted_tags):
                    for j, tag_id in enumerate(instance_tags):
                        predicted_tags_hard_prob[i, j, tag_id] = 1

                golden_tags = input_data_item['iob_tags_label']
                mask = input_data_item['mask']
                union_iob_tags = iob_tags_to_union_iob_tags(golden_tags, mask)

                if self.distributed:
                    dist.barrier()
                self.valid_f1_metrics.update(predicted_tags_hard_prob.long(),
                                             union_iob_tags, new_mask)

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')

        f1_result_dict = self.valid_f1_metrics.result()

        # rollback to train mode
        self.model.train()

        return f1_result_dict

    def average_gradients(self, model):
        '''
        Gradient averaging
        :param model:
        :return:
        '''
        size = float(dist.get_world_size())
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data /= size

    def logger_info(self, msg):
        self.logger.info(msg) if self.local_master else None

    def logger_warning(self, msg):
        self.logger.warning(msg) if self.local_master else None

    def _prepare_device(self, local_rank, local_world_size):
        '''
         setup GPU device if available, move model into configured device
        :param local_rank:
        :param local_world_size:
        :return:
        '''
        if self.distributed:
            ngpu_per_process = torch.cuda.device_count() // local_world_size
            device_ids = list(
                range(local_rank * ngpu_per_process,
                      (local_rank + 1) * ngpu_per_process))

            if torch.cuda.is_available() and local_rank != -1:
                torch.cuda.set_device(
                    device_ids[0]
                )  # device_ids[0] == local_rank when local_world_size equals the number of GPUs per node
                device = 'cuda'
                self.logger_info(
                    f"[Process {os.getpid()}] world_size = {dist.get_world_size()}, "
                    +
                    f"rank = {dist.get_rank()}, n_gpu/process = {ngpu_per_process}, device_ids = {device_ids}"
                )
            else:
                self.logger_warning('Training will be using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, device_ids
        else:
            n_gpu = torch.cuda.device_count()
            print(f"NUMBER GPU {n_gpu}")
            n_gpu_use = local_world_size
            if n_gpu_use > 0 and n_gpu == 0:
                self.logger_warning(
                    "Warning: There's no GPU available on this machine, "
                    "training will be performed on CPU.")
                n_gpu_use = 0
            if n_gpu_use > n_gpu:
                self.logger_warning(
                    "Warning: The number of GPUs configured to use is {}, but only {} are available "
                    "on this machine.".format(n_gpu_use, n_gpu))
                n_gpu_use = n_gpu

            list_ids = list(range(n_gpu_use))
            if n_gpu_use > 0:
                torch.cuda.set_device(
                    list_ids[0])  # only use first available gpu as devices
                self.logger_warning(f'Training is using GPU {list_ids[0]}!')
                device = 'cuda'
            else:
                self.logger_warning('Training is using CPU!')
                device = 'cpu'
            device = torch.device(device)
            return device, list_ids

    def _save_checkpoint(self, epoch, save_best=False):
        '''
        Saving checkpoints
        :param epoch:  current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        :return:
        '''
        # only local master process do save model
        if not self.local_master:
            return

        if hasattr(self.model, 'module'):
            arch = type(self.model.module).__name__
            state_dict = self.model.module.state_dict()
        else:
            arch = type(self.model).__name__
            state_dict = self.model.state_dict()
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': state_dict,
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.monitor_best,
            'config': self.config
        }
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger_info("Saving current best: model_best.pth ...")
        else:
            filename = str(self.checkpoint_dir /
                           'checkpoint-epoch{}.pth'.format(epoch))
            torch.save(state, filename)
            self.logger_info("Saving checkpoint: {} ...".format(filename))

    def _resume_checkpoint(self, resume_path):
        '''
        Resume from saved checkpoints
        :param resume_path: Checkpoint path to be resumed
        :return:
        '''
        resume_path = str(resume_path)
        self.logger_info("Loading checkpoint: {} ...".format(resume_path))
        # map_location = {'cuda:%d' % 0: 'cuda:%d' % self.config['local_rank']}
        checkpoint = torch.load(resume_path, map_location=self.device)
        self.start_epoch = checkpoint['epoch'] + 1
        self.monitor_best = checkpoint['monitor_best']

        # load architecture params from checkpoint.
        if checkpoint['config']['model_arch'] != self.config['model_arch']:
            self.logger_warning(
                "Warning: Architecture configuration given in config file is different from that of "
                "checkpoint. This may yield an exception while state_dict is being loaded."
            )
        self.model.load_state_dict(checkpoint['state_dict'])

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger_warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger_info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
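
For reference, the trainer reads its settings from config['trainer']. The keys below are the ones accessed in the code above, with illustrative values that are assumptions, not taken from an actual config file. Note the monitor format: '<mode> <entity>-<metric>', split first on whitespace and then on '-' in _is_best_monitor_metric.

# Hypothetical 'trainer' section, reconstructed from the keys read above.
cfg_trainer = {
    "epochs": 100,
    "save_period": 10,
    "monitor_open": True,
    "monitor": "max total-f1",     # '<mode> <entity>-<metric>'
    "early_stop": 20,              # -1 disables early stopping
    "tensorboard": True,
    "log_verbosity": 2,
    "log_step_interval": -1,       # -1 -> fall back to sqrt(batch_size)
    "val_step_interval": 50,
    "gl_loss_lambda": 0.01,
    "sync_batch_norm": True,
    "anomaly_detection": False,
}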
Example 8
def train(trainloader, modelS1, modelS2, optimizerS1, optimizerS2, epoch,
          train_writer, gpuDisabled, resultsFile_name):

    lossTracker = MetricTracker()

    modelS1.train()
    modelS2.train()

    for idx, (dataS1, dataS2) in enumerate(tqdm(trainloader, desc="training")):
        numSample = dataS2["bands10"].size(0)

        if not args.lossFunc == 'TripletLoss':
            lossFunc = nn.MSELoss()

            halfNumSample = numSample // 2

            if gpuDisabled:
                bands1 = torch.cat((dataS2["bands10"][:halfNumSample],
                                    dataS2["bands20"][:halfNumSample],
                                    dataS2["bands60"][:halfNumSample]),
                                   dim=1).to(torch.device("cpu"))
                polars1 = torch.cat((dataS1["polarVH"][:halfNumSample],
                                     dataS1["polarVV"][:halfNumSample]),
                                    dim=1).to(torch.device("cpu"))
                labels1 = dataS2["label"][:halfNumSample].to(
                    torch.device("cpu"))

                bands2 = torch.cat((dataS2["bands10"][halfNumSample:],
                                    dataS2["bands20"][halfNumSample:],
                                    dataS2["bands60"][halfNumSample:]),
                                   dim=1).to(torch.device("cpu"))
                polars2 = torch.cat((dataS1["polarVH"][halfNumSample:],
                                     dataS1["polarVV"][halfNumSample:]),
                                    dim=1).to(torch.device("cpu"))
                labels2 = dataS2["label"][halfNumSample:].to(
                    torch.device("cpu"))

                labels = torch.cat((labels1, labels2)).to(torch.device('cpu'))

                onesTensor = torch.ones(halfNumSample)

            else:
                bands1 = torch.cat((dataS2["bands10"][:halfNumSample],
                                    dataS2["bands20"][:halfNumSample],
                                    dataS2["bands60"][:halfNumSample]),
                                   dim=1).to(torch.device("cuda"))
                polars1 = torch.cat((dataS1["polarVH"][:halfNumSample],
                                     dataS1["polarVV"][:halfNumSample]),
                                    dim=1).to(torch.device("cuda"))
                labels1 = dataS2["label"][:halfNumSample].to(
                    torch.device("cuda"))

                bands2 = torch.cat((dataS2["bands10"][halfNumSample:],
                                    dataS2["bands20"][halfNumSample:],
                                    dataS2["bands60"][halfNumSample:]),
                                   dim=1).to(torch.device("cuda"))
                polars2 = torch.cat((dataS1["polarVH"][halfNumSample:],
                                     dataS1["polarVV"][halfNumSample:]),
                                    dim=1).to(torch.device("cuda"))
                labels2 = dataS2["label"][halfNumSample:].to(
                    torch.device("cuda"))

                labels = torch.cat((labels1, labels2)).to(torch.device('cuda'))

                onesTensor = torch.ones(halfNumSample).to(torch.device("cuda"))

            optimizerS1.zero_grad()
            optimizerS2.zero_grad()

            logitsS1_1 = modelS1(polars1)
            logitsS1_2 = modelS1(polars2)

            logitsS2_1 = modelS2(bands1)
            logitsS2_2 = modelS2(bands2)

            cos = torch.nn.CosineSimilarity(dim=1)
            cosBetweenLabels = cos(labels1, labels2)

            cosBetweenS1 = cos(logitsS1_1, logitsS1_2)
            cosBetweenS2 = cos(logitsS2_1, logitsS2_2)

            cosInterSameLabel1 = cos(logitsS1_1, logitsS2_1)
            cosInterSameLabel2 = cos(logitsS1_2, logitsS2_2)

            cosInterDifLabel1 = cos(logitsS1_1, logitsS2_2)
            cosInterDifLabel2 = cos(logitsS1_2, logitsS2_1)

            S1IntraLoss = lossFunc(cosBetweenS1, cosBetweenLabels)
            S2IntraLoss = lossFunc(cosBetweenS2, cosBetweenLabels)

            InterLoss_SameLabel1 = lossFunc(cosInterSameLabel1, onesTensor)
            InterLoss_SameLabel2 = lossFunc(cosInterSameLabel2, onesTensor)

            InterLoss_DifLabel1 = lossFunc(cosInterDifLabel1, cosBetweenLabels)
            InterLoss_DifLabel2 = lossFunc(cosInterDifLabel2, cosBetweenLabels)

            # weighted sum of the intra- and inter-modal MSE terms
            mseLoss = (0.33 * S1IntraLoss + 0.33 * S2IntraLoss +
                       0.0825 * InterLoss_SameLabel1 +
                       0.0825 * InterLoss_SameLabel2 +
                       0.0825 * InterLoss_DifLabel1 +
                       0.0825 * InterLoss_DifLabel2)

            pushLossValue = pushLossInMSE(logitsS1_1, logitsS1_2, logitsS2_1,
                                          logitsS2_2)
            balancingLossValue = balancingLossInMSE(logitsS1_1, logitsS1_2,
                                                    logitsS2_1, logitsS2_2)

            loss = mseLoss - beta * pushLossValue / args.bits + gamma * balancingLossValue

        else:
            if gpuDisabled:
                bands = torch.cat(
                    (dataS2["bands10"], dataS2["bands20"], dataS2["bands60"]),
                    dim=1).to(torch.device("cpu"))
                polars = torch.cat((dataS1["polarVH"], dataS1["polarVV"]),
                                   dim=1).to(torch.device("cpu"))
                labels = dataS2["label"].to(torch.device("cpu"))

            else:
                bands = torch.cat(
                    (dataS2["bands10"], dataS2["bands20"], dataS2["bands60"]),
                    dim=1).to(torch.device("cuda"))
                polars = torch.cat((dataS1["polarVH"], dataS1["polarVV"]),
                                   dim=1).to(torch.device("cuda"))
                labels = dataS2["label"].to(torch.device("cuda"))

            optimizerS1.zero_grad()
            optimizerS2.zero_grad()

            logitsS1 = modelS1(polars)
            logitsS2 = modelS2(bands)

            pushLossValue = pushLoss(logitsS1, logitsS2)
            balancingLossValue = balancingLoss(logitsS1, logitsS2)

            triplets = get_triplets(labels)

            S1IntraLoss = triplet_loss(logitsS1[triplets[0]],
                                       logitsS1[triplets[1]],
                                       logitsS1[triplets[2]])
            S2IntraLoss = triplet_loss(logitsS2[triplets[0]],
                                       logitsS2[triplets[1]],
                                       logitsS2[triplets[2]])

            InterLoss1 = triplet_loss(logitsS1[triplets[0]],
                                      logitsS2[triplets[1]],
                                      logitsS2[triplets[2]])
            InterLoss2 = triplet_loss(logitsS2[triplets[0]],
                                      logitsS1[triplets[1]],
                                      logitsS1[triplets[2]])

            tripletLoss = 0.25 * S1IntraLoss + 0.25 * S2IntraLoss + 0.25 * InterLoss1 + 0.25 * InterLoss2

            loss = tripletLoss - beta * pushLossValue / args.bits + gamma * balancingLossValue

        loss.backward()
        optimizerS1.step()
        optimizerS2.step()

        lossTracker.update(loss.item(), numSample)

    train_writer.add_scalar("loss", lossTracker.avg, epoch)

    print('Train loss: {:.6f}'.format(lossTracker.avg))
    with open(resultsFile_name, 'a') as resultsFile:
        resultsFile.write('Train loss: {:.6f}\n'.format(lossTracker.avg))
Example 9
def train(trainloader, model, lemniscate, criterion, CELoss, lambda_, optimizer, epoch, train_writer):

    losses = MetricTracker()
    sncalosses = MetricTracker()
    celosses = MetricTracker()

    model.train()

    for idx, data in enumerate(tqdm(trainloader, desc="training")):

        imgs = data['img'].to(torch.device("cuda"))
        index = data["idx"].to(torch.device("cuda"))
        labels = data['label'].to(torch.device("cuda"))

        feature, logits = model(imgs)
        output = lemniscate(feature, index)
        
        sncaloss = criterion(output, index)
        celoss = CELoss(logits, labels)

        loss = lambda_ * sncaloss + celoss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), imgs.size(0))
        sncalosses.update(sncaloss.item(), imgs.size(0))
        celosses.update(celoss.item(), imgs.size(0))

    info = {
        "Loss": losses.avg,
        "SNCALoss": sncalosses.avg,
        "CELoss": celosses.avg
    }
    for tag, value in info.items():
        train_writer.add_scalar(tag, value, epoch)

    print('Train TotalLoss: {:.6f} SNCALoss: {:.6f} CELoss: {:.6f}'.format(
            losses.avg,
            sncalosses.avg,
            celosses.avg
            ))
Example 10
def train(trainloader, train_test_loader, model, optimizer, MNCELoss, epoch,
          train_writer):

    global iter_train_test_loader

    total_losses = MetricTracker()
    MNCE_losses = MetricTracker()
    BNM_losses = MetricTracker()

    model.train()

    for idx, train_data in enumerate(tqdm(trainloader, desc="training")):
        iter_num = epoch * len(trainloader) + idx

        if iter_num % len(train_test_loader) == 0:
            iter_train_test_loader = iter(train_test_loader)

        train_imgs = train_data['img'].to(torch.device("cuda"))
        train_labels = train_data['label'].to(torch.device("cuda"))

        test_data = next(iter_train_test_loader)

        test_imgs = test_data['img'].to(torch.device("cuda"))
        test_labels = test_data['label'].to(torch.device("cuda"))

        train_e = model(train_imgs)

        _, mnceloss = MNCELoss(train_e, train_labels)

        test_e = model(test_imgs)

        testlogits, _ = MNCELoss(test_e, test_labels)
        softmax_tgt = nn.Softmax(dim=1)(testlogits)

        # BNM: maximize the mean singular value of the batch softmax predictions
        _, s_tgt, _ = torch.svd(softmax_tgt)
        transfer_loss = -torch.mean(s_tgt)

        total_loss = mnceloss + transfer_loss

        optimizer.zero_grad()

        total_loss.backward()
        optimizer.step()

        total_losses.update(total_loss.item(), train_imgs.size(0))
        MNCE_losses.update(mnceloss.item(), train_imgs.size(0))
        BNM_losses.update(transfer_loss.item(), train_imgs.size(0))

    info = {
        "Loss": total_losses.avg,
        "MNCE": MNCE_losses.avg,
        "BNM": BNM_losses.avg
    }

    for tag, value in info.items():
        train_writer.add_scalar(tag, value, epoch)

    print(
        f'Train TotalLoss: {total_losses.avg:.6f} NCE loss: {MNCE_losses.avg:.6f} BNM loss: {BNM_losses.avg:.6f}'
    )
Example 11
def train_Moco(trainloader, model, model_ema, lemniscate, criterion, CELoss,
               optimizer, epoch, train_writer):

    losses = MetricTracker()
    sncalosses = MetricTracker()
    celosses = MetricTracker()

    model.train()
    model_ema.eval()

    def set_bn_train(m):
        classname = m.__class__.__name__
        if classname.find('BatchNorm') != -1:
            m.train()

    model_ema.apply(set_bn_train)

    for idx, data in enumerate(tqdm(trainloader, desc="training")):

        imgs = data['img'].to(torch.device("cuda"))
        index = data["idx"].to(torch.device("cuda"))
        labels = data['label'].to(torch.device("cuda"))

        bsz = imgs.size(0)

        shuffle_ids, reverse_ids = get_shuffle_ids(bsz)

        feature, logits = model(imgs)

        with torch.no_grad():
            imgs = imgs[shuffle_ids]
            feature_hat, _ = model_ema(imgs)
            feature_hat = feature_hat[reverse_ids]

        output = lemniscate(feature, feature_hat, index)
        snca_moco_loss = criterion(output, index)

        celoss = CELoss(logits, labels)

        loss = snca_moco_loss + celoss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.update(loss.item(), imgs.size(0))
        sncalosses.update(snca_moco_loss.item(), imgs.size(0))
        celosses.update(celoss.item(), imgs.size(0))

        moment_update(model, model_ema, 0.999)

    info = {
        "Loss": losses.avg,
        "SNCALoss": sncalosses.avg,
        "CELoss": celosses.avg
    }
    for tag, value in info.items():
        train_writer.add_scalar(tag, value, epoch)

    print('Train TotalLoss: {:.6f} SNCALoss: {:.6f} CELoss: {:.6f}'.format(
        losses.avg, sncalosses.avg, celosses.avg))