Example #1
# Standard imports needed by the examples below, plus project helpers that are
# assumed to come from a pysot-style code base (module paths may differ);
# build_opt_lr, log_grads and print_speed are expected to be defined in the
# same training script.
import logging
import math
import os
import time

import torch
from torch.nn.utils import clip_grad_norm_

from pysot.core.config import cfg
from pysot.utils.average_meter import AverageMeter
from pysot.utils.misc import describe

logger = logging.getLogger('global')


def train(train_loader, model, optimizer, lr_scheduler, tb_writer):

    cur_lr = lr_scheduler.get_cur_lr()

    #rank = get_rank()
    rank = 0

    average_meter = AverageMeter()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    #world_size = get_world_size()
    world_size = 1
    length = len(train_loader.dataset)  # 64000*50=3200000
    num_per_epoch = len(
        train_loader.dataset) // cfg.TRAIN.EPOCH // (cfg.TRAIN.BATCH_SIZE)

    start_epoch = cfg.TRAIN.START_EPOCH
    epoch = start_epoch

    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR) and rank == 0:
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    #logger.info("model\n{}".format(describe(model.module)))
    end = time.time()
    # idx counts iterations globally; each epoch spans num_per_epoch iterations
    for idx, data in enumerate(train_loader):

        if epoch != idx // num_per_epoch + start_epoch:  # one epoch has finished

            epoch = idx // num_per_epoch + start_epoch
            if rank == 0:
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    cfg.TRAIN.SNAPSHOT_DIR + '/checkpoint_e%d.pth' % (epoch))

            if epoch == cfg.TRAIN.EPOCH:  # all epochs finished
                return

            if cfg.BACKBONE.TRAIN_EPOCH == epoch:
                logger.info('start training backbone.')
                optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
                logger.info("model\n{}".format(describe(model.module)))

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch: {}'.format(epoch + 1))

        tb_idx = idx
        if idx % num_per_epoch == 0 and idx != 0:
            # use a separate index so the outer loop's idx is not clobbered
            for pg_idx, pg in enumerate(optimizer.param_groups):
                logger.info('epoch {} lr {}'.format(epoch + 1, pg['lr']))
                if rank == 0:
                    tb_writer.add_scalar('lr/group{}'.format(pg_idx + 1),
                                         pg['lr'], tb_idx)

        # data_time = average_reduce(time.time() - end)
        data_time = time.time() - end

        if rank == 0:
            tb_writer.add_scalar('time/data', data_time, tb_idx)

        outputs = model(data)
        loss = outputs['total_loss'].mean()

        if is_valid_number(loss.data.item()):
            optimizer.zero_grad()
            loss.backward()
            #reduce_gradients(model)

            if rank == 0 and cfg.TRAIN.LOG_GRADS:
                log_grads(model.module, tb_writer, tb_idx)

            # clip gradient
            clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
            optimizer.step()

        batch_time = time.time() - end

        batch_info = {}
        #batch_info['batch_time'] = average_reduce(batch_time)
        batch_info['batch_time'] = batch_time
        #batch_info['data_time'] = average_reduce(data_time)
        batch_info['data_time'] = data_time
        for k, v in outputs.items():
            #batch_info[k] = average_reduce(v.data.item())
            batch_info[k] = v.mean().item()  # reduce to a python float for logging
        average_meter.update(**batch_info)

        if rank == 0:
            for k, v in batch_info.items():
                tb_writer.add_scalar(k, v, tb_idx)

            if (idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
                info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format(
                    epoch + 1, (idx + 1) % num_per_epoch, num_per_epoch,
                    cur_lr)
                for cc, (k, v) in enumerate(batch_info.items()):
                    if cc % 2 == 0:
                        info += ("\t{:s}\t").format(getattr(average_meter, k))
                    else:
                        info += ("{:s}\n").format(getattr(average_meter, k))
                logger.info(info)
                print_speed(idx + 1 + start_epoch * num_per_epoch,
                            average_meter.batch_time.avg,
                            cfg.TRAIN.EPOCH * num_per_epoch, num_per_epoch)
        end = time.time()
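
A minimal sketch of how the train() above might be driven, assuming a pysot-style project: TrkDataset, ModelBuilder and build_opt_lr are assumed helpers from that project, and the model is wrapped in DataParallel because train() accesses model.module.

import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter


def main():
    # Hypothetical dataset and model builders; replace with your own.
    train_loader = DataLoader(TrkDataset(),
                              batch_size=cfg.TRAIN.BATCH_SIZE,
                              num_workers=cfg.TRAIN.NUM_WORKERS,
                              pin_memory=True)

    model = ModelBuilder().cuda().train()
    optimizer, lr_scheduler = build_opt_lr(model, cfg.TRAIN.START_EPOCH)

    # train() reads model.module, so wrap the model before calling it
    model = torch.nn.DataParallel(model)

    tb_writer = SummaryWriter(cfg.TRAIN.LOG_DIR)
    train(train_loader, model, optimizer, lr_scheduler, tb_writer)
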
Example #2

def train(train_loader, model, optimizer, lr_scheduler, tb_writer):
    cur_lr = lr_scheduler.get_cur_lr()
    # rank = get_rank()

    average_meter = AverageMeter()

    model.train()
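    # Freeze the shared tracker components: backbone, neck and RPN head stay in
    # eval mode and their BatchNorm layers are fixed via BNtoFixed, so only the
    # mask and refine heads are updated in this stage.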
    model.module.backbone.eval()
    model.module.neck.eval()
    model.module.rpn_head.eval()
    model.module.backbone.apply(BNtoFixed)
    model.module.neck.apply(BNtoFixed)
    model.module.rpn_head.apply(BNtoFixed)

    # keep the mask and refine heads in training mode
    model.module.mask_head.train()
    model.module.refine_head.train()
    model = model.cuda()

    def is_valid_number(x):
        return not (math.isnan(x) or math.isinf(x) or x > 1e4)

    # world_size = get_world_size()
    num_per_epoch = len(train_loader.dataset) // \
        cfg.TRAIN.EPOCH // cfg.TRAIN.BATCH_SIZE
    start_epoch = cfg.TRAIN.START_EPOCH
    epoch = start_epoch

    if not os.path.exists(cfg.TRAIN.SNAPSHOT_DIR):
        os.makedirs(cfg.TRAIN.SNAPSHOT_DIR)

    print('******')
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)
    print('******')

    logger.info("model\n{}".format(describe(model.module)))
    end = time.time()
    for idx, data in enumerate(train_loader):
        if epoch != idx // num_per_epoch + start_epoch:
            epoch = idx // num_per_epoch + start_epoch

            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, cfg.TRAIN.SNAPSHOT_DIR + '/checkpoint_e%d.pth' % (epoch))

            if epoch == cfg.TRAIN.EPOCH:
                return

            # if cfg.BACKBONE.TRAIN_EPOCH == epoch:
            #     logger.info('start training backbone.')
            #     optimizer, lr_scheduler = build_opt_lr(model.module, epoch)
            #     logger.info("model\n{}".format(describe(model.module)))

            optimizer, lr_scheduler = build_opt_lr(model.module, epoch)

            lr_scheduler.step(epoch)
            cur_lr = lr_scheduler.get_cur_lr()
            logger.info('epoch: {}'.format(epoch + 1))

        tb_idx = idx
        if idx % num_per_epoch == 0 and idx != 0:
            # use a separate index so the outer loop's idx is not clobbered
            for pg_idx, pg in enumerate(optimizer.param_groups):
                logger.info('epoch {} lr {}'.format(epoch + 1, pg['lr']))
                tb_writer.add_scalar('lr/group{}'.format(pg_idx + 1), pg['lr'],
                                     tb_idx)

        data_time = time.time() - end
        tb_writer.add_scalar('time/data', data_time, tb_idx)

        outputs = model(data)
        # loss = outputs['total_loss']

        loss = torch.mean(outputs['total_loss'])

        if is_valid_number(loss.data.item()):
            optimizer.zero_grad()
            loss.backward()
            # reduce_gradients(model)

            if cfg.TRAIN.LOG_GRADS:
                log_grads(model.module, tb_writer, tb_idx)

            # clip gradient
            clip_grad_norm_(model.parameters(), cfg.TRAIN.GRAD_CLIP)
            optimizer.step()

        batch_time = time.time() - end
        batch_info = {}
        batch_info['batch_time'] = batch_time
        batch_info['data_time'] = data_time
        for k, v in sorted(outputs.items()):
            batch_info[k] = v.data.item()

        average_meter.update(**batch_info)

        for k, v in batch_info.items():
            tb_writer.add_scalar(k, v, tb_idx)

        if (idx + 1) % cfg.TRAIN.PRINT_FREQ == 0:
            info = "Epoch: [{}][{}/{}] lr: {:.6f}\n".format(
                epoch + 1, (idx + 1) % num_per_epoch, num_per_epoch, cur_lr)
            for cc, (k, v) in enumerate(batch_info.items()):
                if cc % 2 == 0:
                    info += ("\t{:s}\t").format(getattr(average_meter, k))
                else:
                    info += ("{:s}\n").format(getattr(average_meter, k))
            logger.info(info)
            print_speed(idx + 1 + start_epoch * num_per_epoch,
                        average_meter.batch_time.avg,
                        cfg.TRAIN.EPOCH * num_per_epoch)
        end = time.time()
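
The second example applies a BNtoFixed helper that is not shown here. A minimal sketch, assuming its only job is to freeze BatchNorm running statistics:

def BNtoFixed(m):
    # Switch every BatchNorm layer to eval mode so its running mean and
    # variance are not updated while the mask and refine heads are trained.
    if m.__class__.__name__.find('BatchNorm') != -1:
        m.eval()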