Example #1
# Imports required to make this snippet self-contained. The last import is an
# assumption: Logger, AverageMeter, and validate_gradient are project-local
# helpers, so adjust the path to wherever they live in your codebase.
import gc
import os

import torch
import MinkowskiEngine as ME
from torch.utils.tensorboard import SummaryWriter  # or: from tensorboardX import SummaryWriter
from tqdm import tqdm

from lib.utils import Logger, AverageMeter, validate_gradient

class Trainer(object):
    def __init__(self, args):
        self.config = args
        # parameters
        self.start_epoch = 1
        self.max_epoch = args.max_epoch
        self.save_dir = args.save_dir
        self.device = args.device
        self.verbose = args.verbose
        self.max_points = args.max_points
        self.voxel_size = args.voxel_size

        self.model = args.model.to(self.device)
        self.optimizer = args.optimizer
        self.scheduler = args.scheduler
        self.scheduler_freq = args.scheduler_freq
        self.snapshot_freq = args.snapshot_freq
        self.snapshot_dir = args.snapshot_dir
        self.benchmark = args.benchmark
        self.iter_size = args.iter_size
        self.verbose_freq = args.verbose_freq

        self.w_circle_loss = args.w_circle_loss
        self.w_overlap_loss = args.w_overlap_loss
        self.w_saliency_loss = args.w_saliency_loss
        self.desc_loss = args.desc_loss

        self.best_loss = 1e5
        self.best_recall = -1e5
        self.writer = SummaryWriter(log_dir=args.tboard_dir)
        self.logger = Logger(args.snapshot_dir)
        self.logger.write(
            f'#parameters {sum(p.nelement() for p in self.model.parameters()) / 1e6:.2f} M\n'
        )

        if (args.pretrain != ''):
            self._load_pretrain(args.pretrain)

        self.loader = dict()
        self.loader['train'] = args.train_loader
        self.loader['val'] = args.val_loader
        self.loader['test'] = args.test_loader

        with open(f'{args.snapshot_dir}/model', 'w') as f:
            f.write(str(self.model))

    def _snapshot(self, epoch, name=None):
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'best_loss': self.best_loss,
            'best_recall': self.best_recall
        }
        if name is None:
            filename = os.path.join(self.save_dir, f'model_{epoch}.pth')
        else:
            filename = os.path.join(self.save_dir, f'model_{name}.pth')
        self.logger.write(f"Save model to {filename}\n")
        torch.save(state, filename)

    def _load_pretrain(self, resume):
        if os.path.isfile(resume):
            state = torch.load(resume)
            self.model.load_state_dict(state['state_dict'])
            self.start_epoch = state['epoch']
            self.scheduler.load_state_dict(state['scheduler'])
            self.optimizer.load_state_dict(state['optimizer'])
            self.best_loss = state['best_loss']
            self.best_recall = state['best_recall']

            self.logger.write(
                f'Successfully load pretrained model from {resume}!\n')
            self.logger.write(f'Current best loss {self.best_loss}\n')
            self.logger.write(f'Current best recall {self.best_recall}\n')
        else:
            raise ValueError(f"=> no checkpoint found at '{resume}'")

    def _get_lr(self, group=0):
        return self.optimizer.param_groups[group]['lr']

    def stats_dict(self):
        stats = dict()
        stats['circle_loss'] = 0.
        # feature match recall, divided by the number of ground-truth pairs
        stats['recall'] = 0.
        stats['saliency_loss'] = 0.
        stats['saliency_recall'] = 0.
        stats['saliency_precision'] = 0.
        stats['overlap_loss'] = 0.
        stats['overlap_recall'] = 0.
        stats['overlap_precision'] = 0.
        return stats

    def stats_meter(self):
        meters = dict()
        stats = self.stats_dict()
        for key in stats:
            meters[key] = AverageMeter()
        return meters

    def inference_one_batch(self, input_dict, phase):
        assert phase in ['train', 'val', 'test']
        ##################################
        # training
        if (phase == 'train'):
            self.model.train()
            ###############################################
            # forward pass
            sinput_src = ME.SparseTensor(input_dict['src_F'].to(self.device),
                                         coordinates=input_dict['src_C'].to(
                                             self.device))
            sinput_tgt = ME.SparseTensor(input_dict['tgt_F'].to(self.device),
                                         coordinates=input_dict['tgt_C'].to(
                                             self.device))

            src_feats, tgt_feats, scores_overlap, scores_saliency = self.model(
                sinput_src, sinput_tgt)
            src_pcd, tgt_pcd = input_dict['pcd_src'].to(
                self.device), input_dict['pcd_tgt'].to(self.device)
            c_rot = input_dict['rot'].to(self.device)
            c_trans = input_dict['trans'].to(self.device)
            correspondence = input_dict['correspondences'].long().to(
                self.device)

            ###################################################
            # get loss
            stats = self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,
                                   correspondence, c_rot, c_trans,
                                   scores_overlap, scores_saliency,
                                   input_dict['scale'])

            c_loss = (stats['circle_loss'] * self.w_circle_loss +
                      stats['overlap_loss'] * self.w_overlap_loss +
                      stats['saliency_loss'] * self.w_saliency_loss)

            c_loss.backward()

        else:
            self.model.eval()
            with torch.no_grad():
                ###############################################
                # forward pass
                sinput_src = ME.SparseTensor(
                    input_dict['src_F'].to(self.device),
                    coordinates=input_dict['src_C'].to(self.device))
                sinput_tgt = ME.SparseTensor(
                    input_dict['tgt_F'].to(self.device),
                    coordinates=input_dict['tgt_C'].to(self.device))

                src_feats, tgt_feats, scores_overlap, scores_saliency = self.model(
                    sinput_src, sinput_tgt)
                src_pcd, tgt_pcd = input_dict['pcd_src'].to(
                    self.device), input_dict['pcd_tgt'].to(self.device)
                c_rot = input_dict['rot'].to(self.device)
                c_trans = input_dict['trans'].to(self.device)
                correspondence = input_dict['correspondences'].long().to(
                    self.device)

                ###################################################
                # get loss
                stats = self.desc_loss(src_pcd, tgt_pcd, src_feats, tgt_feats,
                                       correspondence, c_rot, c_trans,
                                       scores_overlap, scores_saliency,
                                       input_dict['scale'])

        ##################################
        # detach the gradients for loss terms
        stats['circle_loss'] = float(stats['circle_loss'].detach())
        stats['overlap_loss'] = float(stats['overlap_loss'].detach())
        stats['saliency_loss'] = float(stats['saliency_loss'].detach())

        return stats

    def inference_one_epoch(self, epoch, phase):
        gc.collect()
        assert phase in ['train', 'val', 'test']

        # init stats meter
        stats_meter = self.stats_meter()

        num_iter = len(self.loader[phase].dataset) // self.loader[phase].batch_size
        c_loader_iter = iter(self.loader[phase])

        self.optimizer.zero_grad()
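        # gradients are accumulated over `iter_size` mini-batches before each
        # optimizer step (see the update block below), so zero_grad() is only
        # called here and again right after every step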
        for c_iter in tqdm(range(num_iter)):  # loop through this epoch
            inputs = next(c_loader_iter)
            try:
                ##################################
                # forward pass
                # with torch.autograd.detect_anomaly():
                stats = self.inference_one_batch(inputs, phase)

                ###################################################
                # run optimisation
                if ((c_iter + 1) % self.iter_size == 0 and phase == 'train'):
                    gradient_valid = validate_gradient(self.model)
                    if (gradient_valid):
                        self.optimizer.step()
                    else:
                        self.logger.write('gradient not valid\n')
                    self.optimizer.zero_grad()

                ################################
                # update to stats_meter
                for key, value in stats.items():
                    stats_meter[key].update(value)
            except RuntimeError as inst:
                # skip batches that fail at runtime (e.g. CUDA OOM), but log
                # the error instead of swallowing it silently
                self.logger.write(f'RuntimeError in {phase} batch: {inst}\n')

            torch.cuda.empty_cache()

            if (c_iter + 1) % self.verbose_freq == 0 and self.verbose:
                curr_iter = num_iter * (epoch - 1) + c_iter
                for key, value in stats_meter.items():
                    self.writer.add_scalar(f'{phase}/{key}', value.avg,
                                           curr_iter)

                message = f'{phase} Epoch: {epoch} [{c_iter+1:4d}/{num_iter}]\t'
                for key, value in stats_meter.items():
                    message += f'{key}: {value.avg:.2f}\t'

                self.logger.write(message + '\n')

        message = f'{phase} Epoch: {epoch}\t'
        for key, value in stats_meter.items():
            message += f'{key}: {value.avg:.2f}\t'
        self.logger.write(message + '\n')

        return stats_meter

    def train(self):
        print('start training...')
        for epoch in range(self.start_epoch, self.max_epoch):
            self.inference_one_epoch(epoch, 'train')
            self.scheduler.step()

            stats_meter = self.inference_one_epoch(epoch, 'val')

            if stats_meter['circle_loss'].avg < self.best_loss:
                self.best_loss = stats_meter['circle_loss'].avg
                self._snapshot(epoch, 'best_loss')
            if stats_meter['recall'].avg > self.best_recall:
                self.best_recall = stats_meter['recall'].avg
                self._snapshot(epoch, 'best_recall')

            # we only add the saliency loss once we get decent point-wise features
            if (stats_meter['recall'].avg > 0.3):
                self.w_saliency_loss = 1.
            else:
                self.w_saliency_loss = 0.

        # finish all epoch
        print("Training finish!")

    def eval(self):
        print('Starting evaluation on the validation set...')
        stats_meter = self.inference_one_epoch(0, 'val')

        for key, value in stats_meter.items():
            print(key, value.avg)
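
The Trainer above depends on two project-local helpers that the listing does not include: AverageMeter (used by stats_meter) and validate_gradient (used to guard each optimizer step). Below is a minimal sketch of both, consistent with how they are called here but not necessarily identical to the original repository's implementations.

class AverageMeter:
    """Tracks a running average of a scalar; only .update() and .avg are used above."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def validate_gradient(model):
    """Return False if any parameter gradient contains NaN or Inf values."""
    for param in model.parameters():
        if param.grad is None:
            continue
        if torch.any(torch.isnan(param.grad)) or torch.any(torch.isinf(param.grad)):
            return False
    return True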
Example #2
import os
import torch

# NOTE: this snippet is an excerpt -- `opt`, `model`, `optimizer`, `trainer`,
# `logger`, `start_epoch`, `COCO`, and `save_model` are defined earlier in the
# original script. The head of the following DataLoader call was truncated in
# the source; the split, batch size, and shuffle arguments are assumptions
# added to restore a complete statement.
val_loader = torch.utils.data.DataLoader(
    COCO(cfg=opt, split='val', augment=False),
    batch_size=1,
    shuffle=False,
    num_workers=8,
    pin_memory=True
)
train_loader = torch.utils.data.DataLoader(
    COCO(cfg=opt, split='train', augment=True),
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True
)

print('Starting training...')
best = 1e10
for epoch in range(start_epoch + 1, opt.num_epochs + 1):
    log_dict_train, _ = trainer.train(epoch, train_loader)
    logger.write('epoch: {} |'.format(epoch))
    for k, v in log_dict_train.items():
        logger.scalar_summary('train_{}'.format(k), v, epoch)
        logger.write('{} {:8f} | '.format(k, v))
    with torch.no_grad():
        log_dict_val, preds = trainer.val(epoch, val_loader)
    for k, v in log_dict_val.items():
        logger.scalar_summary('val_{}'.format(k), v, epoch)
        logger.write('{} {:8f} | '.format(k, v))
    if log_dict_val['loss'] < best:
        best = log_dict_val['loss']
        save_model(os.path.join(opt.save_dir, 'model_best.pth'),
                   epoch, model)
    save_model(os.path.join(opt.save_dir, 'model_last.pth'),
               epoch, model, optimizer)
    logger.write('\n')
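
The loop above calls a save_model helper that the excerpt does not define. Judging from its two call sites (the optimizer argument is optional), a minimal sketch could look like this; the unwrapping of nn.DataParallel via the hasattr check is an assumption.

def save_model(path, epoch, model, optimizer=None):
    # unwrap nn.DataParallel if the model was wrapped for multi-GPU training
    state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
    data = {'epoch': epoch, 'state_dict': state_dict}
    if optimizer is not None:
        data['optimizer'] = optimizer.state_dict()
    torch.save(data, path)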
Example #3
# Imports required to make this snippet self-contained. The last import is an
# assumption: Logger, TFBLogger, and validate are project-local helpers, so
# adjust the path to match your repository.
import os
import time

import matplotlib.pyplot as plt
import torch

from utils import Logger, TFBLogger, validate

def train(word_emb,
          vision_model,
          language_model,
          ent_loss_model,
          rel_loss_model,
          train_loader,
          val_loader,
          word_dict,
          ent_dict,
          pred_dict,
          n_epochs,
          val_freq,
          out_dir,
          cfg,
          grad_freq=0):

    os.makedirs(out_dir, exist_ok=True)
    params = list(vision_model.parameters()) + list(
        language_model.parameters())
    params = [param for param in params if param.requires_grad]
    named_params = list(vision_model.named_parameters()) + list(
        language_model.named_parameters())
    optimizer = torch.optim.Adam(params,
                                 lr=cfg.train.learning_rate,
                                 weight_decay=cfg.train.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=1, gamma=cfg.train.learning_rate_decay)
    logger = Logger(os.path.join(out_dir, "log.txt"))
    tfb_logger = TFBLogger(out_dir)
    if grad_freq > 0:
        plt.ion()  # interactive matplotlib mode when gradient inspection is on
    n_batches = len(train_loader)
    step = 0
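    # `step` counts training samples seen so far; it is the x-axis for the
    # TensorBoard summaries written inside the batch loop below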

    for epoch in range(n_epochs):

        epoch_loss = 0.0

        if epoch % val_freq == 0:
            vision_model.train(False)
            language_model.train(False)
            ent_acc, rel_acc = validate(word_emb, vision_model, language_model,
                                        val_loader, word_dict, ent_dict,
                                        pred_dict,
                                        cfg.language_model.tokens_length,
                                        tfb_logger, step)
            logstr = "epoch %2d | ent acc(top20): %.3f | rel acc(top20): %.3f" % (
                epoch, ent_acc, rel_acc)
            logger.write("%-80s" % logstr)
            vision_model.train(True)
            language_model.train(True)

        tic_0 = time.time()

        for i, data in enumerate(train_loader):

            tic_1 = time.time()

            image_ids = data[0]
            if len(image_ids) < cfg.train.batch_size:
                continue  # skip incomplete final batches
            images = data[1].cuda().float()
            sbj_boxes = data[2].cuda().float()
            obj_boxes = data[3].cuda().float()
            rel_boxes = data[4].cuda().float()
            sbj_tokens = data[5].cuda()
            obj_tokens = data[6].cuda()
            rel_tokens = data[7].cuda()
            sbj_seq_lens = data[8].cuda().long()
            obj_seq_lens = data[9].cuda().long()
            rel_seq_lens = data[10].cuda().long()

            tic_2 = time.time()

            optimizer.zero_grad()

            sbj_t_emb = language_model(word_emb(sbj_tokens), sbj_seq_lens)
            obj_t_emb = language_model(word_emb(obj_tokens), obj_seq_lens)
            rel_t_emb = language_model(word_emb(rel_tokens), rel_seq_lens)
            sbj_v_emb, obj_v_emb, rel_v_emb = vision_model(
                images, sbj_boxes, obj_boxes, rel_boxes)

            sbj_loss = ent_loss_model(sbj_v_emb, sbj_t_emb)
            obj_loss = ent_loss_model(obj_v_emb, obj_t_emb)
            rel_loss = rel_loss_model(rel_v_emb, rel_t_emb)

            loss = sbj_loss + obj_loss + rel_loss

            tic_3 = time.time()

            loss.backward()
            optimizer.step()

            tic_4 = time.time()

            if grad_freq > 0 and i % grad_freq == 0:
                for n, p in named_params:
                    if not "bias" in n:
                        name_path = n.replace(".", "/")
                        tfb_logger.histo_summary("grad/%s" % name_path,
                                                 p.grad.data.cpu().numpy(),
                                                 step)

            epoch_loss += loss.item() * train_loader.batch_size

            logstr = "epoch %2d batch %4d/%4d | loss %5.2f | %4dms | ^ %4dms | => %4dms" % \
                     (epoch+1, i+1, n_batches, loss.item(),
                      1000*(tic_4-tic_0), 1000*(tic_2-tic_0), 1000*(tic_4-tic_2))
            print("%-80s" % logstr, end="\r")
            tfb_logger.scalar_summary("loss/ent",
                                      sbj_loss.item() + obj_loss.item(), step)
            tfb_logger.scalar_summary("loss/rel", rel_loss.item(), step)
            tfb_logger.scalar_summary("loss/total", loss.item(), step)

            tic_0 = time.time()
            step += train_loader.batch_size

        epoch_loss /= n_batches * train_loader.batch_size

        logstr = "epoch %2d | train_loss: %.3f" % (epoch + 1, epoch_loss)
        logger.write("%-80s" % logstr)

        vision_model_path = os.path.join(out_dir,
                                         "vision_model_%d.pth" % (epoch + 1))
        torch.save(vision_model.state_dict(), vision_model_path)
        language_model_path = os.path.join(
            out_dir, "language_model_%d.pth" % (epoch + 1))
        torch.save(language_model.state_dict(), language_model_path)

        # step the LR scheduler at the end of each epoch, after the optimizer
        # updates (the original code stepped it at the start of the epoch,
        # which decays the learning rate before any training has happened and
        # is deprecated ordering in PyTorch >= 1.1)
        scheduler.step()
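
A hypothetical companion to train(): reloading the per-epoch checkpoints it writes. The function name and the choice of epoch are illustrative, not part of the original script.

def load_checkpoint(vision_model, language_model, out_dir, epoch):
    """Restore the weights saved by train() for the given epoch."""
    vision_model.load_state_dict(
        torch.load(os.path.join(out_dir, "vision_model_%d.pth" % epoch)))
    language_model.load_state_dict(
        torch.load(os.path.join(out_dir, "language_model_%d.pth" % epoch)))
    # switch to inference mode, mirroring the train(False) calls used above
    vision_model.train(False)
    language_model.train(False)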