def init_log(self,
             data_bunch: db.VideoDataBunch = None,
             model: nn.Module = None,
             criterion: nn.Module = None,
             optimizer: optim.Adam = None,
             lr_scheduler: lrs.MultiStepLR = None):
    """Print the run options and, in debug mode, log the training components.

    Args:
        data_bunch: Data bunch whose ``str`` form is logged in debug mode.
        model: Model whose ``str`` form is logged in debug mode.
        criterion: Loss module whose ``str`` form is logged in debug mode.
        optimizer: Optimizer whose ``str`` form is logged in debug mode.
        lr_scheduler: Scheduler; its ``state_dict()`` (or ``None`` when no
            scheduler is given) is logged in debug mode.
    """
    self.print_options()
    if not self.opts.debug:
        return
    for component in (data_bunch, model, criterion, optimizer):
        self.log(str(component))
    # The scheduler is logged through its state dict rather than its str().
    scheduler_state = lr_scheduler.state_dict() if lr_scheduler else None
    self.log(str(scheduler_state))
# Example #2 (score: 0)
def train_general(args):
    """Train FCNet or CE-Net on the loaders returned by ``get_loaders(args)``.

    Training is iteration-based: it runs for ``max_iter`` iterations, prints the
    training loss every ``print_interval`` iterations and validates every
    ``val_interval`` iterations. A checkpoint is written under
    ``results/<model_name>/`` whenever the validation ``cls_acc`` matches or
    beats the best value seen so far.

    Note that ``args.optimizer``, ``args.n_classes`` and ``args.batch_size`` are
    overwritten unconditionally at the start of this function.

    Args:
        args: Namespace providing at least ``model_name`` (``'FCNet'`` or
            ``'CENet'``) and ``model_path`` (a checkpoint to resume from, or
            ``None``). It is also forwarded to the model constructors and to
            ``get_loaders``.

    Raises:
        ValueError: If ``args.model_name`` or ``args.optimizer`` is not one of
            the supported values (previously this surfaced later as a
            NameError/UnboundLocalError).
    """
    args.optimizer = 'Adam'
    args.n_classes = 2
    args.batch_size = 8

    max_iter = 300000
    print_interval = 50
    val_interval = 500

    if args.model_name == 'FCNet':
        model = torch.nn.DataParallel(FCNet(args).cuda())
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(),
                            .1,
                            weight_decay=5e-4,
                            momentum=.99)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .1, weight_decay=5e-4)
        else:
            raise ValueError('Unsupported optimizer {!r}'.format(args.optimizer))
        criterion = cross_entropy2d
        scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
    elif args.model_name == 'CENet':
        model = torch.nn.DataParallel(CE_Net_(args).cuda())
        if args.optimizer == 'SGD':
            optimizer = SGD(model.parameters(),
                            .1,
                            weight_decay=5e-4,
                            momentum=.99)
            scheduler = MultiStepLR(optimizer, [100, 200, 400, 800, 3200], .1)
        elif args.optimizer == 'Adam':
            optimizer = Adam(model.parameters(), .001, weight_decay=5e-4)
            scheduler = MultiStepLR(optimizer, [400, 3200], .1)
        else:
            raise ValueError('Unsupported optimizer {!r}'.format(args.optimizer))
        criterion = DiceLoss()
    else:
        raise ValueError('Unsupported model_name {!r}'.format(args.model_name))

    start_iter = 0
    if args.model_path is not None:
        if os.path.isfile(args.model_path):
            checkpoint = torch.load(args.model_path)
            model.load_state_dict(checkpoint["model_state"])
            optimizer.load_state_dict(checkpoint["optimizer_state"])
            scheduler.load_state_dict(checkpoint["scheduler_state"])
            # The "epoch" key actually stores the iteration count (see save below).
            start_iter = checkpoint["epoch"]
        else:
            print('Unable to load {}'.format(args.model_name))

    train_loader, valid_loader = get_loaders(args)

    # Create output directories; unlike the old bare ``except: pass`` this
    # only tolerates "already exists" and reports real failures immediately.
    os.makedirs('logs', exist_ok=True)
    os.makedirs(os.path.join('results', args.model_name), exist_ok=True)
    writer = SummaryWriter(log_dir='logs/')

    best = -100.0
    i = start_iter
    flag = True

    running_metrics_val = Acc_Meter()
    val_loss_meter = averageMeter()
    time_meter = averageMeter()

    while i <= max_iter and flag:
        for images, labels in train_loader:
            i += 1
            start_ts = time.time()
            model.train()
            images = images.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            outputs = model(images)

            loss = criterion(input=outputs, target=labels)

            loss.backward()
            optimizer.step()
            # Since PyTorch 1.1 the scheduler must be stepped *after* the
            # optimizer; stepping it first skips the initial learning rate.
            scheduler.step()

            time_meter.update(time.time() - start_ts)

            if (i + 1) % print_interval == 0:
                fmt_str = "Iter [{:d}/{:d}]  Loss: {:.4f}  Time/Image: {:.4f}"
                print_str = fmt_str.format(
                    i + 1,
                    max_iter,
                    loss.item(),
                    time_meter.avg / args.batch_size,
                )
                print(print_str)

            if (i + 1) % val_interval == 0 or (i + 1) == max_iter:
                model.eval()
                with torch.no_grad():
                    for i_val, (images_val,
                                labels_val) in tqdm(enumerate(valid_loader)):
                        images_val = images_val.cuda()
                        labels_val = labels_val.cuda()

                        outputs = model(images_val)
                        val_loss = criterion(input=outputs, target=labels_val)

                        # Predicted class = argmax over dim 1 of the outputs.
                        pred = outputs.data.max(1)[1].cpu().numpy()
                        gt = labels_val.data.cpu().numpy()

                        running_metrics_val.update(gt, pred)
                        val_loss_meter.update(val_loss.item())

                print("Iter %d Loss: %.4f" % (i + 1, val_loss_meter.avg))

                results = running_metrics_val.get_acc()
                for k, v in results.items():
                    writer.add_scalar(k, v, i + 1)
                print(results)

                val_loss_meter.reset()
                running_metrics_val.reset()

                if results['cls_acc'] >= best:
                    best = results['cls_acc']
                    state = {
                        "epoch": i + 1,
                        "model_state": model.state_dict(),
                        "optimizer_state": optimizer.state_dict(),
                        "scheduler_state": scheduler.state_dict(),
                        "best": best,
                    }
                    save_path = os.path.join(
                        "results/{}/results_{}_best_model.pkl".format(
                            args.model_name, i + 1), )
                    torch.save(state, save_path)

            if (i + 1) == max_iter:
                flag = False
                break
    writer.close()
# Example #3 (score: 0)
def main():
    """Build a ``*ResNet56_od`` model and train or test it per the global ``args``.

    The model class is chosen by ``args.block_type`` and the data module by
    ``args.dataset``. With ``args.test_only`` the model is evaluated once and
    the function returns; otherwise it trains from ``start_epoch`` up to
    ``args.epochs``, saving through ``ckpt.save_model`` every
    ``args.save_freq`` epochs and whenever top-1 precision improves.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    cudnn.benchmark = True

    start_epoch = args.start_epoch
    lr_decay_step = [int(step) for step in args.lr_decay_step.split(',')]

    # Data loading
    print_logger.info('=> Preparing data..')
    data_module = import_module('data.' + args.dataset)
    loader = data_module.Data(args)

    num_classes = 10 if args.dataset in ['cifar10'] else 0

    # NOTE(review): eval() on a CLI-supplied string executes arbitrary code;
    # consider a lookup in an explicit name -> class mapping instead.
    model_cls = eval(args.block_type + 'ResNet56_od')
    model = model_cls(groups=args.group_num,
                      expansion=args.expansion,
                      num_stu=args.num_stu,
                      num_classes=num_classes).cuda()

    if len(args.gpu) > 1:
        # args.gpu is a comma-separated string such as "0,1", so
        # (len + 1) // 2 counts the listed devices.
        model = torch.nn.DataParallel(
            model, device_ids=list(range((len(args.gpu) + 1) // 2)))

    best_prec = 0.0

    if not model:
        print_logger.info("Model arch Error")
        return

    print_logger.info(model)

    # Loss function, optimizer and LR schedule.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer,
                            milestones=lr_decay_step,
                            gamma=args.lr_decay_factor)

    # Optionally resume from a checkpoint.
    ckpt_path = args.resume
    if ckpt_path:
        print('=> Loading checkpoint {}'.format(ckpt_path))
        checkpoint = torch.load(ckpt_path)
        weights = checkpoint['state_dict']
        if args.adjust_ckpt:
            # Strip DataParallel prefixes from the saved parameter names.
            weights = {
                name.replace('module.', ''): tensor
                for name, tensor in weights.items()
            }

        if args.start_epoch == 0:
            start_epoch = checkpoint['epoch']

        best_prec = checkpoint['best_prec']
        model.load_state_dict(weights)
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        print('=> Continue from epoch {}...'.format(start_epoch))

    if args.test_only:
        test_prec = test(args, loader.loader_test, model)
        print('=> Test Prec@1: {:.2f}'.format(test_prec[0]))
        return

    record_top5 = 0.
    for epoch in range(start_epoch, args.epochs):
        # NOTE(review): step(epoch) is deprecated in recent PyTorch; it is kept
        # because saved scheduler states rely on this epoch numbering.
        scheduler.step(epoch)

        _train_loss, _train_prec = train(args, loader.loader_train, model,
                                         criterion, optimizer, epoch)
        test_prec = test(args, loader.loader_test, model, epoch)

        is_best = best_prec < test_prec[0]
        if is_best:
            record_top5 = test_prec[1]
        best_prec = max(test_prec[0], best_prec)

        state = dict(
            state_dict=model.state_dict(),
            test_prec=test_prec[0],
            best_prec=best_prec,
            optimizer=optimizer.state_dict(),
            scheduler=scheduler.state_dict(),
            epoch=epoch + 1,
        )

        if is_best or epoch % args.save_freq == 0:
            ckpt.save_model(state, epoch + 1, is_best)
        print_logger.info("=>Best accuracy {:.3f}, {:.3f}".format(
            best_prec, record_top5))
# Example #4 (score: 0)
def train(train_loop_func, logger, args):
    """Set up and run SSD300 training (or a one-off evaluation) on COCO.

    Handles optional multi-GPU setup (``WORLD_SIZE`` > 1), seeding, FP16/AMP
    wrapping, resuming from ``args.checkpoint`` and per-epoch checkpointing
    into ``./models/``.

    Args:
        train_loop_func: Callable that runs one epoch and returns the updated
            global iteration counter.
        logger: Progress logger; ``update_epoch_time`` and ``update_epoch`` are
            called on it from local rank 0.
        args: Parsed command-line arguments. ``args.distributed``,
            ``args.N_gpu``, ``args.seed`` and ``args.learning_rate`` are
            (re)assigned here.
    """
    if args.amp:
        amp_handle = amp.init(enabled=args.fp16)
    # GPU use is controlled by args.no_cuda; availability is not checked here.
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        # Offset the seed by rank so every process draws different numbers.
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=args.backbone)
    # Scale the LR linearly with the GPU count and with batch_size / 32.
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size /
                                                            32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    if args.fp16 and not args.amp:
        ssd300 = network_to_half(ssd300)

    if args.distributed:
        ssd300 = DDP(ssd300)

    # Checkpoints store the *unwrapped* model's weights (see the save block
    # below), so loading and saving go through the bare module; loading into
    # the DDP wrapper would fail on the missing "module." key prefix.
    bare_model = ssd300.module if args.distributed else ssd300

    optimizer = torch.optim.SGD(tencent_trick(ssd300),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=args.multistep,
                            gamma=0.1)
    if args.fp16:
        if args.amp:
            optimizer = amp_handle.wrap_optimizer(optimizer)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.)
    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(bare_model, args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.
                                    cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            bare_model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    # Map contiguous training label ids back to the dataset's original ids.
    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        # NOTE(review): stepping before the epoch's optimizer updates is the
        # pre-PyTorch-1.1 convention; kept so existing checkpoints resume with
        # the same LR schedule.
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer,
                                    train_loader, val_dataloader, encoder,
                                    iteration, logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                           args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'label_map': val_dataset.label_info,
                'model': bare_model.state_dict(),
            }
            # torch.save does not create missing parent directories.
            os.makedirs('./models', exist_ok=True)
            torch.save(obj, './models/epoch_{}.pt'.format(epoch))
        train_loader.reset()
    print('total training time: {}'.format(total_time))
# Example #5 (score: 0)
def train(train_loop_func, logger, args):
    """Set up and run SSD300 training (or a one-off evaluation) on COCO.

    Handles optional multi-GPU setup (``WORLD_SIZE`` > 1, ``smddp`` backend),
    seeding, AMP, resuming from ``args.checkpoint`` and per-epoch
    checkpointing into the ``args.save`` directory.

    Args:
        train_loop_func: Callable that runs one epoch and returns the updated
            global iteration counter.
        logger: Progress logger used for epoch time, throughput, the saved
            model path and the final summary.
        args: Parsed command-line arguments. ``args.distributed``,
            ``args.N_gpu``, ``args.seed`` and ``args.learning_rate`` are
            (re)assigned here.
    """
    # GPU use is controlled by args.no_cuda; availability is not checked here.
    use_cuda = not args.no_cuda
    # Hard-coded training-set size used only for throughput reporting
    # (presumably COCO train2017 -- confirm if the dataset changes).
    train_samples = 118287

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='smddp', init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    # torch.distributed.get_rank() raises when no process group has been
    # initialised, so only query it in distributed mode.
    is_main_process = not args.distributed or torch.distributed.get_rank() == 0

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        # Offset the seed by rank so every process draws different numbers.
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    train_loader = get_train_loader(args, args.seed - 2**31)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = SSD300(backbone=ResNet(args.backbone, args.backbone_path))
    # Scale the LR linearly with the GPU count and with batch_size / 32.
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32)
    start_epoch = 0
    iteration = 0
    loss_func = Loss(dboxes)

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer, milestones=args.multistep, gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300, args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    # Map contiguous training label ids back to the dataset's original ids.
    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        # NOTE(review): stepping before the epoch's optimizer updates is the
        # pre-PyTorch-1.1 convention; kept so existing checkpoints resume with
        # the same LR schedule.
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, epoch, optimizer, train_loader, val_dataloader, encoder, iteration,
                                    logger, args, mean, std)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if is_main_process:
            throughput = train_samples / end_epoch_time
            logger.update_epoch_time(epoch, end_epoch_time)
            logger.update_throughput_speed(epoch, throughput)

        if epoch in args.evaluation:
            # NOTE(review): the returned accuracy is not logged here; presumably
            # evaluate() reports it itself -- confirm.
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {'epoch': epoch + 1,
                   'iteration': iteration,
                   'optimizer': optimizer.state_dict(),
                   'scheduler': scheduler.state_dict(),
                   'label_map': val_dataset.label_info}
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            # torch.save does not create missing parent directories.
            os.makedirs(args.save, exist_ok=True)
            save_path = os.path.join(args.save, f'epoch_{epoch}.pt')
            torch.save(obj, save_path)
            logger.log('model path', save_path)
        train_loader.reset()

    if is_main_process:
        DLLogger.log((), { 'Total training time': '%.2f' % total_time + ' secs' })
        logger.log_summary()
# Destination of the final checkpoint; the file name embeds args.gpu.
model_state_path = os.path.join(
    Path.MODEL_DIR, 'densenet161_noisy_gpu_{}.tar'.format(args.gpu))

# Train on two loaders at once: each step takes one batch from focus_dl and
# one from train_dl. zip() stops at the shorter of the two loaders.
for epoch in range(MAX_ITERATIONS):
    model.train()
    for idx, ((x1, y1), (x2, y2)) in enumerate(zip(focus_dl, train_dl)):
        optimizer.zero_grad()
        x1, y1 = x1.float().cuda(), y1.float().cuda()
        x2, y2 = x2.float().cuda(), y2.float().cuda()
        # NOTE(review): mixup_data receives the same batch as both inputs;
        # confirm that mixing a batch with itself is intended.
        x2, y2 = mixup_data(x2, y2, x2, y2, alpha=alpha)
        out2 = model(x2)
        out1 = model(x1)
        loss2 = criterion(out2, y2)
        loss1 = criterion(out1, y1)
        # w1/w2 are indexed by epoch, so they presumably hold one weight per
        # epoch (at least MAX_ITERATIONS entries) -- confirm.
        loss = w1[epoch] * loss1 + w2[epoch] * loss2
        loss.backward()
        optimizer.step()

    # The LR scheduler advances once per epoch, after that epoch's updates.
    scheduler.step()

# Persist model, optimizer and scheduler state after training.
model_state = {
    "model_name": 'freesound',
    "optimizer": optimizer.state_dict(),
    "scheduler": scheduler.state_dict(),
    "state_dict": model.state_dict()
}
torch.save(model_state, model_state_path)

# Reload the weights just written back into the model.
model_state = torch.load(model_state_path)
model.load_state_dict(model_state["state_dict"])
# Example #7 (score: 0)
def main():
    """Train or evaluate ``ChannelDistillResNet1834`` on ImageFolder datasets.

    All settings are read from the project-level ``Config`` object.

    If ``Config.evaluate`` names a checkpoint file, the model is loaded from it,
    validated once, and the function returns. Otherwise training resumes from
    ``Config.resume`` when that path exists. After every epoch the full
    training state is written to ``<Config.checkpoints>/latest.pth``, and it is
    copied to ``best.pth`` whenever validation top-1 accuracy improves. The best
    accuracy is stored in the checkpoint so it survives a resume.

    Raises:
        Exception: If no CUDA device is available, or if ``Config.evaluate`` is
            set but is not a file.
    """
    if not torch.cuda.is_available():
        raise Exception("need gpu to train network!")

    torch.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    cudnn.benchmark = True
    cudnn.enabled = True

    logger = get_logger(__name__, Config.log)

    Config.gpus = torch.cuda.device_count()
    logger.info("use {} gpus".format(Config.gpus))
    config = {
        key: value
        for key, value in Config.__dict__.items() if not key.startswith("__")
    }
    logger.info(f"args: {config}")

    start_time = time.time()

    # dataset and dataloader
    logger.info("start loading data")

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = ImageFolder(Config.train_dataset_path, train_transform)
    train_loader = DataLoader(
        train_dataset,
        batch_size=Config.batch_size,
        shuffle=True,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    val_dataset = ImageFolder(Config.val_dataset_path, val_transform)
    val_loader = DataLoader(
        val_dataset,
        batch_size=Config.batch_size,
        num_workers=Config.num_workers,
        pin_memory=True,
    )
    logger.info("finish loading data")

    # network
    net = ChannelDistillResNet1834(Config.num_classes, Config.dataset_type)
    net = nn.DataParallel(net).cuda()

    # One loss module per Config.loss_list entry; entries whose loss_type
    # contains "kd" are constructed with loss_item["T"] (presumably the
    # distillation temperature).
    criterion = []
    for loss_item in Config.loss_list:
        loss_name = loss_item["loss_name"]
        loss_type = loss_item["loss_type"]
        if "kd" in loss_type:
            criterion.append(losses.__dict__[loss_name](loss_item["T"]).cuda())
        else:
            criterion.append(losses.__dict__[loss_name]().cuda())

    optimizer = SGD(net.parameters(),
                    lr=Config.lr,
                    momentum=0.9,
                    weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

    # only evaluate
    if Config.evaluate:
        # load best model
        if not os.path.isfile(Config.evaluate):
            raise Exception(
                f"{Config.evaluate} is not a file, please check it again")
        logger.info("start evaluating")
        logger.info(f"start resuming model from {Config.evaluate}")
        checkpoint = torch.load(Config.evaluate,
                                map_location=torch.device("cpu"))
        net.load_state_dict(checkpoint["model_state_dict"])
        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"epoch {checkpoint['epoch']:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )
        return

    start_epoch = 1
    best_acc = 0.
    # resume training (an empty or None Config.resume disables resuming)
    if Config.resume and os.path.exists(Config.resume):
        logger.info(f"start resuming model from {Config.resume}")
        checkpoint = torch.load(Config.resume,
                                map_location=torch.device("cpu"))
        start_epoch += checkpoint["epoch"]
        net.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        # Older checkpoints have no "best_acc"; fall back to the fresh default.
        best_acc = checkpoint.get("best_acc", best_acc)
        logger.info(
            f"finish resuming model from {Config.resume}, epoch {checkpoint['epoch']}, "
            f"loss: {checkpoint['loss']:.3f}, lr: {checkpoint['lr']:.6f}, "
            f"top1_acc: {checkpoint['acc']}%")

    os.makedirs(Config.checkpoints, exist_ok=True)
    latest_path = os.path.join(Config.checkpoints, "latest.pth")
    best_path = os.path.join(Config.checkpoints, "best.pth")

    logger.info("start training")
    for epoch in range(start_epoch, Config.epochs + 1):
        prec1, prec5, loss = train(train_loader, net, criterion, optimizer,
                                   scheduler, epoch, logger)
        logger.info(
            f"train: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        prec1, prec5 = validate(val_loader, net)
        logger.info(
            f"val: epoch {epoch:0>3d}, top1 acc: {prec1:.2f}%, top5 acc: {prec5:.2f}%"
        )

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_acc
        if is_best:
            best_acc = prec1
        torch.save(
            {
                "epoch": epoch,
                "acc": prec1,
                "best_acc": best_acc,
                "loss": loss,
                # scheduler.get_lr() is not the current LR outside of
                # scheduler.step() (MultiStepLR re-applies gamma at a
                # milestone); read the value the optimizer actually holds.
                "lr": optimizer.param_groups[0]["lr"],
                "model_state_dict": net.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
            }, latest_path)
        if is_best:
            shutil.copyfile(latest_path, best_path)

    training_time = (time.time() - start_time) / 3600
    logger.info(
        f"finish training, best acc: {best_acc:.2f}%, total training time: {training_time:.2f} hours"
    )
# Example #8 (score: 0)
def main():
    """Train a metric-learning model on the TOUGH pocket dataset.

    Writes the command line to ``<output_dir>/cmdline.txt``, builds (or
    resumes) the model, optimizer and LR scheduler, then trains for
    ``args.epochs`` epochs. Every ``args.test_nth_epoch`` epochs and after the
    last epoch it evaluates on the test split. Scalars are logged to
    TensorBoard and ``<output_dir>/model.pth.tar`` is saved after each epoch.
    Training stops early if the mean training loss becomes NaN.

    Gradient accumulation: each optimizer step covers ``args.batch_parts``
    loader batches of ``args.batch_size // args.batch_parts`` samples each.
    """
    args = get_cli_args()
    print('Will save to ' + args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, 'cmdline.txt'), 'w') as f:
        # Quote every non-option argument so the command can be re-run.
        f.write(" ".join([
            "'" + a + "'" if (len(a) == 0 or a[0] != '-') else a
            for a in sys.argv
        ]))

    set_seed(args.seed)
    device = torch.device(args.device)
    writer = SummaryWriter(args.output_dir)

    train_dataset, test_dataset = create_tough_dataset(
        args,
        fold_nr=args.cvfold,
        n_folds=args.num_folds,
        seed=args.seed,
        exclude_Vertex_from_train=args.db_exclude_vertex,
        exclude_Prospeccts_from_train=args.db_exclude_prospeccts)
    logger.info('Train set size: %d, test set size: %d', len(train_dataset),
                len(test_dataset))

    # Create model and optimizer (or resume pre-existing)
    if args.resume != '':
        if args.resume == 'RESUME':
            args.resume = args.output_dir + '/model.pth.tar'
        model, optimizer, scheduler = resume(args, train_dataset, device)
    else:
        model = create_model(args, train_dataset, device)
        if args.input_normalization:
            model.set_input_scaler(
                estimate_scaler(args, train_dataset, nsamples=200))
        optimizer = create_optimizer(args, model)
        scheduler = MultiStepLR(optimizer,
                                milestones=args.lr_steps,
                                gamma=args.lr_decay)

    ############
    def train():
        """Run one training epoch and return mean loss and distance statistics."""
        model.train()

        loader = torch.utils.data.DataLoader(train_dataset,
                                             batch_size=args.batch_size //
                                             args.batch_parts,
                                             num_workers=args.nworkers,
                                             shuffle=True,
                                             drop_last=True,
                                             worker_init_fn=set_worker_seed)

        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        loss_buffer, loss_stabil_buffer, pos_dist_buffer, neg_dist_buffer = [], [], [], []
        t0 = time.time()

        for bidx, batch in enumerate(loader):
            if 0 < args.max_train_samples < bidx * args.batch_size // args.batch_parts:
                break
            t_loader = 1000 * (time.time() - t0)

            inputs = batch['inputs'].to(
                device)  # dimensions: batch_size x (4 or 2) x 24 x 24 x 24
            targets = batch['targets'].to(device)

            # Start of an accumulation window: clear gradients.
            if bidx % args.batch_parts == 0:
                optimizer.zero_grad()
            t0 = time.time()

            # Fold the per-sample group dimension into the batch for the
            # forward pass, then restore it for the loss.
            outputs = model(inputs.view(-1, *inputs.shape[2:]))
            outputs = outputs.view(*inputs.shape[:2], -1)
            loss_joint, loss_match, loss_stabil, pos_dist, neg_dist = compute_loss(
                args, outputs, targets, True)
            loss_joint.backward()

            # End of an accumulation window: average the summed gradients and
            # step. Parameters that received no gradient have grad None and
            # are skipped (dividing them used to raise AttributeError).
            if bidx % args.batch_parts == args.batch_parts - 1:
                if args.batch_parts > 1:
                    for p in model.parameters():
                        if p.grad is not None:
                            p.grad.data.div_(args.batch_parts)
                optimizer.step()

            t_trainer = 1000 * (time.time() - t0)
            loss_buffer.append(loss_match.item())
            loss_stabil_buffer.append(loss_stabil.item(
            ) if isinstance(loss_stabil, torch.Tensor) else loss_stabil)
            pos_dist_buffer.extend(pos_dist.cpu().numpy().tolist())
            neg_dist_buffer.extend(neg_dist.cpu().numpy().tolist())
            logger.debug(
                'Batch loss %f, Loader time %f ms, Trainer time %f ms.',
                loss_buffer[-1], t_loader, t_trainer)
            t0 = time.time()

        ret = {
            'loss': np.mean(loss_buffer),
            'loss_stabil': np.mean(loss_stabil_buffer),
            'pos_dist': np.mean(pos_dist_buffer),
            'neg_dist': np.mean(neg_dist_buffer)
        }
        return ret

    ############
    def test():
        """Evaluate on the test split (no gradients) and return mean statistics."""
        model.eval()

        loader = torch.utils.data.DataLoader(test_dataset,
                                             batch_size=args.batch_size //
                                             args.batch_parts,
                                             num_workers=args.nworkers,
                                             worker_init_fn=set_worker_seed)

        if logging.getLogger().getEffectiveLevel() > logging.DEBUG:
            loader = tqdm(loader, ncols=100)

        loss_buffer, loss_stabil_buffer, pos_dist_buffer, neg_dist_buffer = [], [], [], []

        with torch.no_grad():
            for bidx, batch in enumerate(loader):
                if 0 < args.max_test_samples < bidx * args.batch_size // args.batch_parts:
                    break
                inputs = batch['inputs'].to(device)
                targets = batch['targets'].to(device)

                outputs = model(inputs.view(-1, *inputs.shape[2:]))
                outputs = outputs.view(*inputs.shape[:2], -1)
                loss_joint, loss_match, loss_stabil, pos_dist, neg_dist = compute_loss(
                    args, outputs, targets, False)

                loss_buffer.append(loss_match.item())
                loss_stabil_buffer.append(loss_stabil.item(
                ) if isinstance(loss_stabil, torch.Tensor) else loss_stabil)
                pos_dist_buffer.extend(pos_dist.cpu().numpy().tolist())
                neg_dist_buffer.extend(neg_dist.cpu().numpy().tolist())

        return {
            'loss': np.mean(loss_buffer),
            'loss_stabil': np.mean(loss_stabil_buffer),
            'pos_dist': np.mean(pos_dist_buffer),
            'neg_dist': np.mean(neg_dist_buffer)
        }

    ############
    # Training loop
    for epoch in range(args.start_epoch, args.epochs):
        print(f'Epoch {epoch}/{args.epochs} ({args.output_dir}):')
        # NOTE(review): stepping before the epoch's optimizer updates is the
        # pre-PyTorch-1.1 convention; kept so resumed runs keep their schedule.
        scheduler.step()

        train_stats = train()
        for k, v in train_stats.items():
            writer.add_scalar('train/' + k, v, epoch)
        print(
            f"-> Train distances: p {train_stats['pos_dist']}, n {train_stats['neg_dist']}, \tLoss: {train_stats['loss']}"
        )

        if (epoch + 1) % args.test_nth_epoch == 0 or epoch + 1 == args.epochs:
            test_stats = test()
            for k, v in test_stats.items():
                writer.add_scalar('test/' + k, v, epoch)
            print(
                f"-> Test distances: p {test_stats['pos_dist']}, n {test_stats['neg_dist']}, \tLoss: {test_stats['loss']}"
            )

        torch.save(
            {
                'epoch': epoch + 1,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, os.path.join(args.output_dir, 'model.pth.tar'))

        if math.isnan(train_stats['loss']):
            break
# Example #9 (score: 0)
def _evaluate_net(net, val_data_loader, num_samples, device):
    """Measure accuracy, parameter count and profiled time of ``net``.

    Args:
        net: classifier to evaluate. It is switched to eval mode so BatchNorm
            uses (and does not update) its running statistics and dropout is off.
        val_data_loader: iterable of ``(images, labels)`` batches.
        num_samples: number of validation samples used to normalize the results.
        device: ``torch.device`` the batches are moved to.

    Returns:
        ``(accuracy, number of parameters as reported by gnp, summed profiler
        cuda_time / num_samples, summed profiler cpu_time / num_samples)``.
    """
    net.eval()
    # Only request CUDA timings from the profiler when actually running on CUDA.
    use_cuda = device.type == 'cuda'
    total_correct = 0
    cuda_time = 0.0
    cpu_time = 0.0
    with torch.no_grad():
        for images, labels in val_data_loader:
            images, labels = images.to(device), labels.to(device)

            with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
                output = net(images)
            cuda_time += sum(item.cuda_time for item in prof.function_events)
            cpu_time += sum(item.cpu_time for item in prof.function_events)

            pred = output.detach().max(1)[1]
            total_correct += pred.eq(labels.view_as(pred)).sum().item()

        num_params = gnp(net)

    return (float(total_correct) / num_samples,
            num_params,
            cuda_time / num_samples,
            cpu_time / num_samples)


def main():
    """Train the GNN pruning policy, then retrain and evaluate pruned networks.

    Stage 1 trains ``GNNPrunningNet`` on the graph of ``network_name`` with
    ``GNN_prune_loss``, optionally with:

    * a temperature-sharpened output (``use_temp``), where T doubles every
      ``num_epochs // 3`` epochs;
    * alternating sparsity/data gradient steps (``use_steps``).

    Every 10 epochs it logs the accumulated losses to TensorBoard and writes
    a checkpoint.

    Stage 2 handles each prune factor in ``trials``. It derives a pruned
    network with ``getPrunedNet``, retrains it with SGD for
    ``n_retrain_epochs``, and compares it with the original network:
    accuracy, parameter count, and profiled per-sample CPU/GPU time.
    """
    # random.seed(0)
    # torch.manual_seed(0)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
    # np.random.seed(0)
    # torch.cuda.manual_seed(0)

    # set all hyperparameters
    network_name = 'WRN_40_2'
    num_epochs = 35
    batch_size = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    n_retrain_epochs = 40
    trials = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9]
    lr = 3e-4
    opt = "Adam"
    use_temp = False
    use_steps = False

    # set paths
    checkpointPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '{}')
    continue_train = False
    checkpointLoadPath = './GNN_model/CIFAR10_checkpoints/CP__num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}__epoch_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps, '20')

    # get GNN path
    info = networks_data.get(network_name)
    trained_model_path = info.get('trained_GNN_path').replace('.pt', '___num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}.pt'.format(num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps))

    # declare GNN model; the LR scheduler only exists for the SGD configuration
    model = GNNPrunningNet(in_channels=6, out_channels=128).to(device)
    if opt == "Adam":
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    else:
        # lr = 0.1
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*num_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)
    crit = GNN_prune_loss

    # declare TensorBoard writer
    summary_path = '{}-num_e_{}__retrain_e_{}__lr_{}__opt_{}__useTemp_{}__useSteps_{}/training'.format(network_name, num_epochs, n_retrain_epochs, lr, opt, use_temp, use_steps)
    writer = SummaryWriter(summary_path)

    root            = info.get('root')
    net_graph_path  = info.get('graph_path')
    sd_path         = info.get('sd_path')
    net             = info.get('network')
    orig_net_loss   = info.get('orig_net_loss')

    isWRN = (network_name == "WRN_40_2")
    train_dataset   = GraphDataset(root, network_name, isWRN, net_graph_path)
    train_loader    = DataLoader(train_dataset, batch_size=batch_size)

    orig_net = net().to(device)
    orig_net.load_state_dict(torch.load(sd_path, map_location=device))

    model.train()

    dataset_name = info.get('dataset_name')
    network_train_data = datasets_train.get(dataset_name)

    print("Start training")

    if continue_train:
        cp = torch.load(checkpointLoadPath, map_location=device)
        # 'epoch' is the 0-based index of the last finished epoch.
        trained_epochs = cp['epoch'] + 1
        model.load_state_dict(cp['model_state_dict'])
        optimizer.load_state_dict(cp['optimizer_state_dict'])
        # SGD checkpoints also carry the LR scheduler; restore it so the decay
        # milestones continue from where the run stopped.
        if opt != "Adam" and 'scheduler_state_dict' in cp:
            scheduler.load_state_dict(cp['scheduler_state_dict'])
    else:
        trained_epochs = 0

    # Accumulators over the current 10-epoch logging window (reset after logging).
    loss_all = 0.0
    data_all = 0.0
    sparse_all = 0.0
    if use_temp:
        T = 1.0
        if trained_epochs > 0:
            # Re-derive the temperature reached after `trained_epochs` epochs.
            T = np.power(2, np.floor(trained_epochs / int(num_epochs/3)))

    for epoch in range(trained_epochs, num_epochs):
        epoch_loss = 0.0

        for data in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            output = model(data)

            if use_temp:
                # Temperature sharpening: e^(T*o) / (e^(T*o) + e^(T*(1-o))).
                base = torch.exp(torch.tensor(T, device=device))
                nom = torch.pow(base, output)
                dom = nom + torch.pow(base, (1 - output))
                output = nom / dom

            sparse_term, data_term, data_grad = crit(output, orig_net, orig_net_loss, network_name, network_train_data, device, gamma1=10, gamma2=0.1)

            if use_steps:
                # One epoch in three takes the sparsity step, the other two the data step.
                if epoch % 3 == 0:
                    sparse_term.backward()
                else:
                    output.backward(data_grad)
            else:
                # Both gradients flow through `output`, so keep the graph for the second pass.
                sparse_term.backward(retain_graph=True)
                output.backward(data_grad)

            batch_data = data.num_graphs * data_term.item()
            batch_sparse = data.num_graphs * sparse_term.item()
            data_all += batch_data
            sparse_all += batch_sparse
            # Add only this batch's contribution (not the running totals).
            loss_all += batch_data + batch_sparse
            epoch_loss += batch_data + batch_sparse
            optimizer.step()

        print("epoch {}. total loss is: {}".format(epoch+1, epoch_loss / len(train_dataset)))

        if opt != "Adam":
            scheduler.step()

        if use_temp:
            # increase temperature 3 times over the run
            if (epoch+1) % int(num_epochs/3) == 0:
                T *= 2

        if epoch % 10 == 9:
            writer.add_scalars('Learning curve', {
                'loss data term': data_all/10,
                'loss sparsity term': sparse_all/10,
                'training loss': loss_all/10
            }, epoch+1)

            # save checkpoint (scheduler state only exists for SGD)
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss_all,
            }
            if opt != "Adam":
                checkpoint['scheduler_state_dict'] = scheduler.state_dict()
            torch.save(checkpoint, checkpointPath.format(epoch+1))

            loss_all = 0.0
            data_all = 0.0
            sparse_all = 0.0

    torch.save(model.state_dict(), trained_model_path)

    print("Start evaluating")

    model.load_state_dict(torch.load(trained_model_path, map_location=device))

    model.eval()

    network_val_data = datasets_test.get(dataset_name)
    val_data_loader = torch.utils.data.DataLoader(network_val_data, batch_size=1024, shuffle=False, num_workers=8)
    num_val_samples = len(network_val_data)

    for p_factor in trials:
        # Build the pruned network from the GNN prediction (last graph wins).
        prunedNet = None
        with torch.no_grad():
            for data in train_loader:
                data = data.to(device)

                pred = model(data)

                prunedNet = getPrunedNet(pred, orig_net, network_name, prune_factor=p_factor).to(device)
        if prunedNet is None:
            raise RuntimeError("train_loader yielded no graphs; cannot build a pruned network")

        # Train the pruned network
        prunedNet.train()

        data_train_loader = torch.utils.data.DataLoader(network_train_data, batch_size=256, shuffle=False, num_workers=8)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(prunedNet.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=5e-4)
        scheduler = MultiStepLR(optimizer, milestones=[int(elem*n_retrain_epochs) for elem in [0.3, 0.6, 0.8]], gamma=0.2)

        for epoch in range(n_retrain_epochs):
            for i, (images, labels) in enumerate(data_train_loader):
                images, labels = images.to(device), labels.to(device)
                optimizer.zero_grad()
                output = prunedNet(images)
                loss = criterion(output, labels)

                if i % 30 == 0:
                    print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch+1, i, loss.detach().cpu().item()))

                loss.backward()
                optimizer.step()

            scheduler.step()

        # Evaluate the pruned net (in eval mode)
        p_acc, p_num_params, p_cuda_time, p_cpu_time = _evaluate_net(prunedNet, val_data_loader, num_val_samples, device)

        print("The pruned network for prune factor {} accuracy is: {}".format(p_factor, p_acc))
        print("The pruned network number of parameters is: {}".format(p_num_params))
        print("The pruned network cuda time is: {}".format(p_cuda_time))
        print("The pruned network cpu time is: {}".format(p_cpu_time))

        # Evaluate the original net (in eval mode, so its BN statistics stay untouched)
        o_acc, o_num_params, o_cuda_time, o_cpu_time = _evaluate_net(orig_net, val_data_loader, num_val_samples, device)

        print("The original network accuracy is: {}".format(o_acc))
        print("The original network number of parameters is: {}".format(o_num_params))
        print("The original network cuda time is: {}".format(o_cuda_time))
        print("The original network cpu time is: {}".format(o_cpu_time))

        writer.add_scalars('Network accuracy', {
            'original': o_acc,
            'pruned': p_acc
            }, 100*p_factor)
        writer.add_scalars('Network number of parameters', {
            'original': o_num_params,
            'pruned': p_num_params
            }, 100*p_factor)
        writer.add_scalars('Network GPU time', {
            'original': o_cuda_time,
            'pruned': p_cuda_time
            }, 100*p_factor)
        writer.add_scalars('Network CPU time', {
            'original': o_cpu_time,
            'pruned': p_cpu_time
            }, 100*p_factor)

    writer.close()
Пример #10
0
class ImageNetAgent:
    """Train and validate an image classifier on ImageFolder-style data.

    Supports single-process training and DistributedDataParallel training
    (``config['train']['mode'] == 'parallel'``). It logs losses/accuracy to
    TensorBoard from the logging process (``rank <= 0``) and keeps the
    checkpoint with the lowest validation loss in ``<log_dir>/best.pth``.

    Args:
        config: nested configuration dict with ``train``, ``dataset``,
            ``dataloader``, ``model``, ``optimizer`` and ``scheduler`` sections.
        rank: process rank in parallel mode (indexes ``config['train']['gpus']``);
            ``-1`` for single-process training.
    """

    def __init__(self, config, rank=-1):
        self.rank = rank
        self.config = config
        is_parallel = config['train']['mode'] == 'parallel'

        # Training environment
        if is_parallel:
            gpu_id = config['train']['gpus'][rank]
            self.device = "cuda:{}".format(gpu_id)
        else:
            self.device = config['train']['device'] if torch.cuda.is_available(
            ) else "cpu"

        # Dataset
        crop_size = (config['dataset']['size'], config['dataset']['size'])
        train_transform = T.Compose([
            T.RandomResizedCrop(crop_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        valid_transform = T.Compose([
            T.Resize(256),
            T.CenterCrop(crop_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        train_dataset = ImageFolder(config['dataset']['train']['root'],
                                    transform=train_transform)
        valid_dataset = ImageFolder(config['dataset']['valid']['root'],
                                    transform=valid_transform)

        # Dataloader: in parallel mode the DistributedSampler shards/shuffles.
        if is_parallel:
            self.sampler = DistributedSampler(train_dataset)
            self.train_loader = DataLoader(
                train_dataset,
                sampler=self.sampler,
                batch_size=config['dataloader']['batch_size'],
                num_workers=config['dataloader']['num_workers'],
                pin_memory=True,
                shuffle=False)
        else:
            self.train_loader = DataLoader(
                train_dataset,
                batch_size=config['dataloader']['batch_size'],
                num_workers=config['dataloader']['num_workers'],
                pin_memory=True,
                shuffle=True)

        self.valid_loader = DataLoader(
            valid_dataset,
            batch_size=config['dataloader']['batch_size'],
            num_workers=config['dataloader']['num_workers'],
            pin_memory=True,
            shuffle=False)

        # Model
        if config['model']['name'] == "resnet18":
            model_cls = resnet18
        else:
            model_cls = get_model_cls(config['model']['name'])
        model = model_cls(**config['model']['kwargs'])
        if is_parallel:
            model = model.to(self.device)
            self.model = DDP(model, device_ids=[config['train']['gpus'][rank]])
        else:
            self.model = model.to(self.device)

        # Optimizer
        self.optimizer = optim.SGD(
            self.model.parameters(),
            lr=config['optimizer']['lr'],
            momentum=config['optimizer']['momentum'],
            weight_decay=config['optimizer']['weight_decay'])
        # Scheduler (stepped once per epoch in train())
        self.scheduler = MultiStepLR(
            self.optimizer,
            milestones=config['scheduler']['milestones'],
            gamma=config['scheduler']['gamma'])

        # Loss function
        self.criterion = nn.CrossEntropyLoss().to(self.device)

        # Tensorboard: created under the same `rank <= 0` condition that
        # train_one_epoch()/validate() use before writing to it.
        self.log_dir = osp.join(config['train']['log_dir'],
                                config['train']['exp_name'])
        if self.rank <= 0:
            self.writer = SummaryWriter(logdir=self.log_dir)

        # Dynamic state
        self.current_epoch = -1
        self.current_loss = 10000

    def resume(self):
        """Restore model/optimizer/scheduler and progress from ``best.pth``."""
        checkpoint_path = osp.join(self.log_dir, 'best.pth')

        if self.config['train']['mode'] == 'parallel':
            # The checkpoint is written by rank 0; remap its GPU to ours.
            master_gpu_id = self.config['train']['gpus'][0]
            map_location = {'cuda:{}'.format(master_gpu_id): self.device}
        else:
            # Load onto this agent's device (e.g. a GPU checkpoint on a CPU host).
            map_location = self.device
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        # Load pretrained model
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        # best.pth is saved from validate(), i.e. *before* scheduler.step() for
        # `current_epoch`. Advance the schedule once so epoch current_epoch + 1
        # runs with the learning rate it would have had without interruption.
        self.scheduler.step()

        # Resume to training state
        self.current_loss = checkpoint['current_loss']
        self.current_epoch = checkpoint['current_epoch']
        print("Resume Training at epoch {}".format(self.current_epoch))

    def train(self):
        """Run the remaining epochs: train, validate, then step the LR schedule."""
        for epoch in range(self.current_epoch + 1,
                           self.config['train']['n_epochs']):
            self.current_epoch = epoch
            if self.config['train']['mode'] == 'parallel':
                # Give the DistributedSampler a different shuffle each epoch.
                self.sampler.set_epoch(self.current_epoch)
            self.train_one_epoch()
            self.validate()
            self.scheduler.step()

    def train_one_epoch(self):
        """Train over ``train_loader`` once and log mean loss / accuracy."""
        losses = []
        running_samples = 0
        running_corrects = 0
        self.model.train()
        loop = tqdm(
            self.train_loader,
            desc=
            (f"[{self.rank}] Train Epoch {self.current_epoch}/{self.config['train']['n_epochs']}"
             f"- LR: {self.optimizer.param_groups[0]['lr']:.3f}"),
            leave=True)
        for imgs, labels in loop:
            imgs = imgs.to(self.device, non_blocking=True)
            labels = labels.to(self.device, non_blocking=True)

            outputs = self.model(imgs)
            loss = self.criterion(outputs, labels)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            preds = torch.max(outputs.detach(), 1)[1]
            running_corrects += float(torch.sum(preds == labels))
            running_samples += imgs.size(0)

            losses.append(loss.item())
            loop.set_postfix(loss=sum(losses) / len(losses))

        if self.rank <= 0:
            epoch_loss = sum(losses) / len(losses)
            epoch_acc = running_corrects / running_samples
            self.writer.add_scalar("Train Loss", epoch_loss,
                                   self.current_epoch)
            self.writer.add_scalar("Train Acc", epoch_acc, self.current_epoch)
            print("Epoch {}:{}, Train Loss: {:.2f}, Train Acc: {:.2f}".format(
                self.current_epoch, self.config['train']['n_epochs'],
                epoch_loss, epoch_acc))

    def validate(self):
        """Evaluate on ``valid_loader``; checkpoint when the mean loss improves."""
        losses = []
        running_samples = 0
        running_corrects = 0
        self.model.eval()
        loop = tqdm(
            self.valid_loader,
            desc=
            (f"Valid Epoch {self.current_epoch}/{self.config['train']['n_epochs']}"
             f"- LR: {self.optimizer.param_groups[0]['lr']:.3f}"),
            leave=True)
        with torch.no_grad():
            for imgs, labels in loop:
                imgs = imgs.to(self.device, non_blocking=True)
                labels = labels.to(self.device, non_blocking=True)

                outputs = self.model(imgs)
                loss = self.criterion(outputs, labels)

                preds = torch.max(outputs, 1)[1]
                running_corrects += float(torch.sum(preds == labels))
                running_samples += imgs.size(0)

                losses.append(loss.item())
                loop.set_postfix(loss=sum(losses) / len(losses))

        if self.rank <= 0:
            epoch_loss = sum(losses) / len(losses)
            epoch_acc = running_corrects / running_samples
            print("Epoch {}:{}, Valid Loss: {:.2f}, Valid Acc: {:.2f}".format(
                self.current_epoch, self.config['train']['n_epochs'],
                epoch_loss, epoch_acc))
            self.writer.add_scalar("Valid Loss", epoch_loss,
                                   self.current_epoch)
            self.writer.add_scalar("Valid Acc", epoch_acc, self.current_epoch)
            if epoch_loss < self.current_loss:
                self.current_loss = epoch_loss
                self._save_checkpoint()

    def finalize(self):
        pass

    def _save_checkpoint(self):
        """Write model/optimizer/scheduler state and progress to ``best.pth``."""
        checkpoints = {
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'current_epoch': self.current_epoch,
            'current_loss': self.current_loss
        }
        checkpoint_path = osp.join(self.log_dir, 'best.pth')
        torch.save(checkpoints, checkpoint_path)
        print("Save checkpoint to '{}'".format(checkpoint_path))
def main():
    """Fine-tune a pretrained ExporingResNet as a speaker classifier.

    The pretrained model is prepared as follows:

    * build the network and load the ``args.resume`` checkpoint, dropping
      ``num_batches_tracked`` buffers;
    * replace ``model.fc2`` with a fresh
      ``nn.Linear(args.embedding_size, len(train_dir.speakers))``.

    It is then trained with SGD. The new ``fc2`` head uses 10x the base
    learning rate, and a MultiStepLR decays every group by 0.1 at epoch 8.
    Each epoch calls ``train(...)`` and then ``test(...)``.

    This function relies on module-level globals defined elsewhere (``args``,
    ``train_dir``, ``valid_dir``, ``test_part``, ``kwargs``, ``writer``,
    ``train``, ``test``, ``ExporingResNet``).

    NOTE(review): the code after ``writer.close()`` appears to be pasted from
    a different script and is presumably broken. Confirm before relying on it:

    * ``i``, ``log_dir``, ``validation_dataloader``, ``loss``, ``gpu``,
      ``validation_dataset``, ``checkpoint_dir``, ``validation`` and ``cu``
      are not defined in this function.
    * ``vali_loss_p`` is assigned inside this function, which makes it local,
      so the ``if vali_loss < vali_loss_p`` read raises UnboundLocalError.
    * ``start - end`` subtracts a wall-clock timestamp from an epoch index,
      not an elapsed time.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate
    # model and initialize weights
    # num_classes=1211 presumably matches the head stored in args.resume — TODO confirm
    model = ExporingResNet(layers=[3, 4, 6, 3], num_classes=1211)

    # NOTE(review): `assert` is stripped under `python -O`; a missing file then fails inside torch.load.
    assert os.path.isfile(args.resume)
    print('=> loading checkpoint {}'.format(args.resume))
    checkpoint = torch.load(args.resume)
    # Drop BatchNorm `num_batches_tracked` buffers before loading into the model.
    filtered = {
        k: v
        for k, v in checkpoint['state_dict'].items()
        if 'num_batches_tracked' not in k
    }
    model.load_state_dict(filtered)

    # Replace the classification head so it matches the current speaker set.
    model.fc2 = nn.Linear(args.embedding_size, len(train_dir.speakers))
    # criterion = AngularSoftmax(in_feats=args.embedding_size,
    #                           num_classes=len(train_dir.classes))
    if args.cuda:
        model.cuda()

    # Everything except fc2 forms the base parameter group (compared by object id).
    fc2_params = list(map(id, model.fc2.parameters()))
    base_params = filter(lambda p: id(p) not in fc2_params, model.parameters())

    # The new head learns 10x faster than the pretrained backbone.
    optimizer = torch.optim.SGD(
        [{
            'params': base_params
        }, {
            'params': model.fc2.parameters(),
            'lr': args.lr * 10
        }],
        lr=args.lr,
        momentum=args.momentum,
    )

    # optimizer2 = create_optimizer(model.fc2.parameters(), args.optimizer, **opt_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[8], gamma=0.1)

    # Optionally snapshot the freshly initialized model as checkpoint_0.
    start = 0
    if args.save_init:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start)
        torch.save(
            {
                'epoch': start,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, check_path)

    # Epoch range: [args.start_epoch, args.start_epoch + args.epochs).
    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    # pdb.set_trace()
    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    # NOTE(review): unconditional .cuda() requires CUDA even when args.cuda is False.
    criterion = nn.CrossEntropyLoss().cuda()

    for epoch in range(start, end):
        # pdb.set_trace()
        for param_group in optimizer.param_groups:
            print(
                '\n\33[1;34m Current \'{}\' learning rate is {}.\33[0m'.format(
                    args.optimizer, param_group['lr']))

        train(train_loader, model, optimizer, criterion, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)

        scheduler.step()
        # break

    writer.close()
    # NOTE(review): `end` is overwritten with a timestamp and `start` is an epoch index,
    # so `start - end` is not an elapsed time; `i` is not defined in this function.
    end = time.time()
    log_2 = 'Using {} sec for epoch {}'.format(start - end, i + 1)
    cu.write_log(log_dir, log_2)

    ## validation

    vali_loss, recall, precision, accuracy = validation(
        model, validation_dataloader, loss, gpu, optimizer, validation_dataset)
    log_3 = '----------------Recall for epoch {} is : {}'.format(i + 1, recall)
    log_4 = '----------------Precision for epoch {} is : {}'.format(
        i + 1, precision)
    log_5 = '----------------Accuracy for epoch {} is {}'.format(
        i + 1, accuracy)
    log_6 = '----------------Validation loss for epoch {} is {}'.format(
        i + 1, vali_loss)

    cu.write_log(log_dir, log_3)
    cu.write_log(log_dir, log_4)
    cu.write_log(log_dir, log_5)
    cu.write_log(log_dir, log_6)
    # NOTE(review): vali_loss_p is local (assigned below) and never initialized, so this raises UnboundLocalError.
    if vali_loss < vali_loss_p:
        vali_loss_p = vali_loss  # give loss new value
        checkpoint = {
            'model_state': model.state_dict(),
            'criterion_state': loss.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'scheduler_state': scheduler.state_dict(),
            'epochs': i + 1
        }
        torch.save(checkpoint, checkpoint_dir + 'model' + '.pth')
Пример #13
0
class Trainner(object):
    """Train an EDSeg model as an image autoencoder with an MSE reconstruction loss.

    The model is optimized with Adam, and a MultiStepLR (milestone at
    iteration 1600) is stepped every iteration. Scalars and images are sent
    to ``summary``, and a checkpoint is handed to ``saver`` after every
    validation pass.
    """

    def __init__(self, opt, saver, summary):
        self.opt = opt
        self.saver = saver
        self.summary = summary
        self.global_steps = 0
        # Epoch train() starts from; load_checkpoint() moves it past the restored epoch.
        self.start_epoch = 0
        self.setup_models()
        print('initialize trainner')

    def setup_models(self):
        """Create the loss, model, optimizer and per-iteration LR schedule."""
        self.criteriasMSE = nn.MSELoss()
        self.model = EDSeg(self.opt).cuda(self.opt.device)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.opt.lr)
        self.schedule = MultiStepLR(self.optimizer, milestones=[1600], gamma=self.opt.gamma)

        if self.opt.device_count > 1:
            self.model = nn.DataParallel(self.model, device_ids=self.opt.devices)

    @staticmethod
    def _split_batch(data):
        """Return ``(images, labels)`` for a batch; labels is None for image-only batches.

        DataLoader's default collate turns tuple samples into a *list*, so both
        tuples and lists are treated as ``(images, labels, ...)``.
        """
        if isinstance(data, (tuple, list)):
            return data[0], data[1]
        return data, None

    def _to_device(self, images, labels):
        """Move images (and labels, when present) to ``opt.device``."""
        images = images.to(self.opt.device)
        # Use an explicit None check: `if labels` would call Tensor.__bool__,
        # which raises for multi-element tensors.
        labels = labels.to(self.opt.device) if labels is not None else None
        return images, labels

    def train_iter(self, images, labels):
        """One optimization step on a batch; logs every ``opt.v_freq`` global steps."""
        self.model.train()
        output = self.model(images)
        loss = self.criteriasMSE(output, images)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.schedule.step()

        if self.global_steps % self.opt.v_freq == 0:
            lr = self.optimizer.param_groups[0]['lr']
            self.summary.add_scalar('loss', loss.item(), self.global_steps)
            self.summary.add_scalar('lr', lr, self.global_steps)
            self.summary.train_image(images, output, labels, self.global_steps)

    def train_epoch(self, dataloader):
        """Run one epoch over ``dataloader``, keeping ``global_steps`` in sync."""
        iterator = tqdm(dataloader,
                        leave=True,
                        dynamic_ncols=True)
        self.dataset_len = len(dataloader)
        for i, data in enumerate(iterator):
            iterator.set_description(f'Epoch[{self.epoch}/{self.opt.epochs}|{self.global_steps}]')
            self.global_steps = self.epoch * self.dataset_len + i

            images, labels = self._to_device(*self._split_batch(data))
            self.train_iter(images, labels)

    def train(self, train_dataloader, valid_dataloader):
        """Alternate training and validation until ``opt.epochs`` epochs are done.

        Starts at epoch 0, or right after the restored epoch when
        ``load_checkpoint()``/``resume()`` was called.
        """
        self.epoch = getattr(self, 'start_epoch', 0)
        while self.epoch < self.opt.epochs:
            self.train_epoch(train_dataloader)
            self.validate(valid_dataloader)
            self.epoch += 1

    def validate(self, dataloader):
        """Compute the mean reconstruction MSE, log it and save a checkpoint.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        self.model.eval()
        errs = []
        images = labels = output = None

        # No gradients are needed for evaluation; store plain floats so no
        # autograd graph is kept alive across batches.
        with torch.no_grad():
            for data in dataloader:
                images, labels = self._to_device(*self._split_batch(data))
                output = self.model(images)
                errs.append(self.criteriasMSE(output, images).item())

        if not errs:
            raise ValueError('validation dataloader produced no batches')

        mean_err = torch.tensor(errs).mean()

        self.summary.add_scalar('validate_error', mean_err, self.global_steps)
        # Logs the last validation batch.
        self.summary.val_image(images, output, labels, self.global_steps)
        self.save_checkpoint(mean_err)

    def save_checkpoint(self, err):
        """Hand the current training state to ``saver``."""
        state = {'epoch': self.epoch,
                 'err': err,
                 'model': self.model.state_dict(),
                 'optimizer': self.optimizer.state_dict(),
                 'schedule': self.schedule.state_dict()}
        self.saver.save_checkpoint(state)

    def load_checkpoint(self, state):
        """Restore state produced by ``save_checkpoint``.

        The saved epoch has already been trained and validated, so training
        continues from the following epoch.
        """
        self.model.load_state_dict(state['model'])
        self.optimizer.load_state_dict(state['optimizer'])
        self.schedule.load_state_dict(state['schedule'])
        self.epoch = state['epoch']
        self.start_epoch = state['epoch'] + 1

    def resume(self):
        """Load a checkpoint selected by the ``opt.resume*`` options, if enabled.

        Raises:
            RuntimeError: if resuming is enabled but no source is selected.
        """
        if not self.opt.resume:
            print('resume not enabled, pass')
            return
        if self.opt.resume_best:
            state = self.saver.load_best()
        elif self.opt.resume_latest:
            state = self.saver.load_latest()
        elif self.opt.resume_epoch is not None:
            state = self.saver.load_epoch(self.opt.resume_epoch)
        else:
            raise RuntimeError('resume settings error, please check your config file.')
        self.load_checkpoint(state)
Пример #14
0
class Processor():
    """Processor for Skeleton-based Action Recognition.

    Sets up single-node distributed training (one process per GPU, rank from
    the ``RANK`` environment variable). It builds the model, optimizer, LR
    scheduler and data loaders from ``arg``, and drives training through an
    mmcv-style ``Runner``. ``eval`` and ``start`` provide stand-alone
    evaluation.
    """

    # Parameters/buffers that are always taken from the freshly built model,
    # rather than from a *.pt / *.pkl pretrained file, when converting weights.
    # NOTE(review): presumably these layers differ between the pretrained and
    # the current architecture -- confirm against the model definition.
    _KEYS_KEPT_FROM_MODEL = (
        'data_bn.weight',
        'data_bn.bias',
        'data_bn.running_mean',
        'data_bn.running_var',
        'fc.weight',
        'fc.bias',
        'gcn3d1.gcn3d.0.gcn3d.1.A_res',
        'gcn3d1.gcn3d.1.gcn3d.1.A_res',
        'sgcn1.0.A_res',
        'gcn3d2.gcn3d.0.gcn3d.1.A_res',
        'gcn3d2.gcn3d.1.gcn3d.1.A_res',
        'sgcn2.0.A_res',
        'gcn3d3.gcn3d.0.gcn3d.1.A_res',
        'gcn3d3.gcn3d.1.gcn3d.1.A_res',
        'sgcn3.0.A_res',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.0.weight',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.0.weight',
        'sgcn1.0.mlp.layers.0.weight',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.0.bias',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.1.weight',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.1.bias',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.1.running_mean',
        'gcn3d1.gcn3d.0.gcn3d.1.mlp.layers.1.running_var',
        'gcn3d1.gcn3d.0.out_conv.weight',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.0.bias',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.1.weight',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.1.bias',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.1.running_mean',
        'gcn3d1.gcn3d.1.gcn3d.1.mlp.layers.1.running_var',
        'gcn3d1.gcn3d.1.out_conv.weight',
    )

    def __init__(self, arg):
        """Initialise distributed state, model, optimizer, scheduler, data and runner.

        Args:
            arg: parsed configuration namespace (stored as ``self.arg``).
        """
        self.arg = arg
        self.save_arg()
        self._init_dist_pytorch(backend='nccl',
                                world_size=torch.cuda.device_count())
        self.load_model()
        self.load_param_groups()
        self.load_optimizer()
        self.load_lr_scheduler()
        self.load_data()

        # NOTE(review): this overwrites any global_step parsed from the weights
        # filename in load_model(); confirm whether that is intended.
        self.global_step = 0
        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_acc_epoch = 0

        if self.arg.half:
            self.print_log('*************************************')
            self.print_log('*** Using Half Precision Training ***')
            self.print_log('*************************************')
            self.model, self.optimizer = apex.amp.initialize(
                self.model,
                self.optimizer,
                opt_level=f'O{self.arg.amp_opt_level}')
            if self.arg.amp_opt_level != 1:
                self.print_log(
                    '[WARN] nn.DataParallel is not yet supported by amp_opt_level != "O1"'
                )
        # Builds the Runner; note this rebinds ``self.runner`` from the method
        # to the Runner instance.
        self.runner()

    def _init_dist_pytorch(self, backend, **kwargs):
        """Pin this process to a GPU chosen from ``RANK`` and join the process group."""
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        torch.cuda.set_device(rank % num_gpus)
        dist.init_process_group(backend=backend, **kwargs)

    def load_model(self):
        """Build the network and loss, optionally load pretrained weights, and wrap them.

        Sets ``self._model`` (bare network), ``self.loss``, ``self._model_full``
        (network + loss via ``msg3d_with_loss``) and ``self.model`` (the
        ``MMDistributedDataParallel`` wrapper).
        """
        Model = import_class(self.arg.model)

        # Copy model file and main script into the work dir for reproducibility.
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        shutil.copy2(os.path.join('.', __file__), self.arg.work_dir)

        self._model = Model(**self.arg.model_args).cuda()
        self.loss = nn.CrossEntropyLoss().cuda()
        self.print_log(
            f'Model total number of params: {count_params(self._model)}')

        if self.arg.weights:
            self._load_pretrained_weights(self.arg.weights)

        print(self.arg.center_loss)
        self._model_full = msg3d_with_loss(self._model, self.loss,
                                           self.arg.center_loss)
        rank = int(os.environ['RANK'])
        self._model_full.to(rank)
        self.model = MMDistributedDataParallel(self._model_full.cuda())

    def _load_pretrained_weights(self, weights_path):
        """Load a *.pth / *.pt / *.pkl weights file into ``self._model``.

        Raises:
            ValueError: the path has none of the supported extensions.
        """
        try:
            # Names look like ``<prefix>-<...>-<global_step>.pt``; ``[:-3]``
            # strips a 3-character extension such as ``.pt``.
            self.global_step = int(weights_path[:-3].split('-')[-1])
        except ValueError:
            print('Cannot parse global_step from model weights filename')
            self.global_step = 0

        self.print_log(f'Loading weights from {weights_path}')
        if '.pkl' in weights_path:
            # Pickles are binary; text mode made pickle.load fail.
            # NOTE(review): pickle.load can execute arbitrary code -- only load
            # weight files from trusted sources.
            with open(weights_path, 'rb') as f:
                weights = pickle.load(f)
        elif '.pth' in weights_path:
            weights = torch.load(weights_path)["state_dict"]
            weights = OrderedDict([[k.split('network.')[-1], v.cuda()]
                                   for k, v in weights.items()])
        else:
            weights = torch.load(weights_path)
            weights = OrderedDict([[k.split('module.')[-1], v.cuda()]
                                   for k, v in weights.items()])

        for w in self.arg.ignore_weights:
            if weights.pop(w, None) is not None:
                self.print_log(f'Sucessfully Remove Weights: {w}')
            else:
                self.print_log(f'Can Not Remove Weights: {w}')

        if '.pth' in weights_path:
            try:
                self._model.load_state_dict(weights)
            except RuntimeError:
                # Key mismatch: report missing keys, then load what matches.
                state = self._model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                self.print_log('Can not find these weights:')
                for d in diff:
                    self.print_log('  ' + d)
                state.update(weights)
                self._model.load_state_dict(state)
        elif weights_path.endswith(('.pt', '.pkl')):
            model_params = self._model.state_dict()
            for key in self._KEYS_KEPT_FROM_MODEL:
                weights[key] = model_params[key]
            model_params.update(weights)
            self._model.load_state_dict(model_params, strict=False)
        else:
            # Raising a plain str is itself a TypeError in Python 3.
            raise ValueError('Support *.pth or *.pkl or *.pt pretrain')

    def load_param_groups(self):
        """
        Template function for setting different learning behaviour
        (e.g. LR, weight decay) of different groups of parameters
        """
        self.param_groups = defaultdict(list)
        # Currently every parameter goes into a single 'other' group.
        self.param_groups['other'].extend(self.model.parameters())

        self.optim_param_groups = {
            'other': {
                'params': self.param_groups['other']
            }
        }

    def load_optimizer(self):
        """Create the SGD/Adam optimizer and optionally restore its state.

        Raises:
            ValueError: ``arg.optimizer`` is neither 'SGD' nor 'Adam'.
        """
        params = list(self.optim_param_groups.values())
        if self.arg.optimizer == 'SGD':
            self.optimizer = optim.SGD(params,
                                       lr=self.arg.base_lr,
                                       momentum=0.9,
                                       nesterov=self.arg.nesterov,
                                       weight_decay=self.arg.weight_decay)
        elif self.arg.optimizer == 'Adam':
            self.optimizer = optim.Adam(params,
                                        lr=self.arg.base_lr,
                                        weight_decay=self.arg.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(
                self.arg.optimizer))

        # Load optimizer states if any
        if self.arg.checkpoint is not None:
            self.print_log(
                f'Loading optimizer states from: {self.arg.checkpoint}')
            self.optimizer.load_state_dict(
                torch.load(self.arg.checkpoint)['optimizer_states'])
            current_lr = self.optimizer.param_groups[0]['lr']
            self.print_log(f'Starting LR: {current_lr}')
            self.print_log(
                f'Starting WD1: {self.optimizer.param_groups[0]["weight_decay"]}'
            )
            if len(self.optimizer.param_groups) >= 2:
                self.print_log(
                    f'Starting WD2: {self.optimizer.param_groups[1]["weight_decay"]}'
                )

    def load_lr_scheduler(self):
        """Create a MultiStepLR (gamma 0.1 at ``arg.step``) and optionally restore it."""
        self.lr_scheduler = MultiStepLR(self.optimizer,
                                        milestones=self.arg.step,
                                        gamma=0.1)
        if self.arg.checkpoint is not None:
            scheduler_states = torch.load(
                self.arg.checkpoint)['lr_scheduler_states']
            self.print_log(
                f'Loading LR scheduler states from: {self.arg.checkpoint}')
            self.lr_scheduler.load_state_dict(scheduler_states)
            self.print_log(
                f'Starting last epoch: {scheduler_states["last_epoch"]}')
            # Previously logged last_epoch here by mistake.
            self.print_log(
                f'Loaded milestones: {scheduler_states["milestones"]}')

    def load_data(self):
        """Build the train (distributed-sampled) and test data loaders.

        Batch size and worker count are split evenly across the GPUs.
        """
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()

        def worker_seed_fn(worker_id):
            # give workers different seeds
            return init_seed(self.arg.seed + worker_id + 1)

        rank = int(os.environ['RANK'])
        world_size = torch.cuda.device_count()
        if self.arg.phase == 'train':
            dataset_train = Feeder(**self.arg.train_feeder_args)
            sampler_train = DistributedSampler(dataset_train,
                                               world_size,
                                               rank,
                                               shuffle=True)
            self.data_loader['train'] = torch.utils.data.DataLoader(
                dataset=dataset_train,
                batch_size=self.arg.batch_size // world_size,
                sampler=sampler_train,
                shuffle=False,
                num_workers=self.arg.num_worker // world_size,
                drop_last=True,
                worker_init_fn=worker_seed_fn)

        dataset_test = Feeder(**self.arg.test_feeder_args)
        # NOTE(review): the test loader is not distributed-sampled, so every
        # rank evaluates the full test set.
        self.data_loader['test'] = torch.utils.data.DataLoader(
            dataset=dataset_test,
            batch_size=self.arg.test_batch_size // world_size,
            shuffle=False,
            num_workers=self.arg.num_worker // world_size,
            drop_last=False,
            worker_init_fn=worker_seed_fn)

    def save_arg(self):
        """Dump the configuration to ``<work_dir>/config.yaml``, creating the directory."""
        arg_dict = vars(self.arg)
        os.makedirs(self.arg.work_dir, exist_ok=True)
        with open(os.path.join(self.arg.work_dir, 'config.yaml'), 'w') as f:
            yaml.dump(arg_dict, f)

    def print_time(self):
        """Log the current local time."""
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log(f'Local current time: {localtime}')

    def print_log(self, s, print_time=True):
        """Print ``s`` (optionally time-stamped) and append it to ``log.txt`` if enabled."""
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            s = f'[ {localtime} ] {s}'
        print(s)
        if self.arg.print_log:
            with open(os.path.join(self.arg.work_dir, 'log.txt'), 'a') as f:
                print(s, file=f)

    def record_time(self):
        """Store and return the current wall-clock time."""
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        """Return seconds elapsed since the last record, then reset the mark."""
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    def save_states(self, epoch, states, out_folder, out_name):
        """``torch.save`` ``states`` to ``<work_dir>/<out_folder>/<out_name>``."""
        out_folder_path = os.path.join(self.arg.work_dir, out_folder)
        out_path = os.path.join(out_folder_path, out_name)
        os.makedirs(out_folder_path, exist_ok=True)
        torch.save(states, out_path)

    def save_checkpoint(self, epoch, out_folder='checkpoints'):
        """Save optimizer and LR-scheduler state for resuming."""
        state_dict = {
            'epoch': epoch,
            'optimizer_states': self.optimizer.state_dict(),
            'lr_scheduler_states': self.lr_scheduler.state_dict(),
        }

        checkpoint_name = f'checkpoint-{epoch}-fwbz{self.arg.forward_batch_size}-{int(self.global_step)}.pt'
        self.save_states(epoch, state_dict, out_folder, checkpoint_name)

    def save_weights(self, epoch, out_folder='weights'):
        """Save the model weights (``module.`` prefix stripped, moved to CPU)."""
        state_dict = self.model.state_dict()
        weights = OrderedDict([[k.split('module.')[-1],
                                v.cpu()] for k, v in state_dict.items()])

        weights_name = f'weights-{epoch}-{int(self.global_step)}.pt'
        self.save_states(epoch, weights, out_folder, weights_name)

    def runner(self):
        """Create the mmcv Runner and register training/eval hooks.

        Rebinds ``self.runner`` to the created Runner instance.
        """
        def parse_losses(losses):
            # Reduce each loss entry to a scalar; the total sums every entry
            # whose name contains 'loss'.
            log_vars = OrderedDict()
            for loss_name, loss_value in losses.items():
                if isinstance(loss_value, torch.Tensor):
                    log_vars[loss_name] = loss_value.mean()
                elif isinstance(loss_value, list):
                    log_vars[loss_name] = sum(_loss.mean()
                                              for _loss in loss_value)
                else:
                    raise TypeError(
                        '{} is not a tensor or list of tensors'.format(
                            loss_name))

            loss = sum(_value for _key, _value in log_vars.items()
                       if 'loss' in _key)

            log_vars['loss'] = loss
            for name in log_vars:
                log_vars[name] = log_vars[name].item()

            return loss, log_vars

        def batch_processor(model, data, train_mode):
            losses = model(**data)
            loss, log_vars = parse_losses(losses)
            outputs = dict(loss=loss,
                           log_vars=log_vars,
                           num_samples=len(data['batchdata'].data))
            return outputs

        self.runner = Runner(self.model, batch_processor, self.optimizer,
                             self.arg.work_dir)
        optimizer_config = DistOptimizerHook(
            grad_clip=dict(max_norm=20, norm_type=2))
        if "policy" not in self.arg.policy:
            lr_config = dict(policy='step', step=self.arg.step)
        else:
            lr_config = dict(**self.arg.policy)
        checkpoint_config = dict(interval=5)
        log_config = dict(interval=20,
                          hooks=[
                              dict(type='TextLoggerHook'),
                              dict(type='TensorboardLoggerHook')
                          ])
        self.runner.register_training_hooks(lr_config, optimizer_config,
                                            checkpoint_config, log_config)
        self.runner.register_hook(DistSamplerSeedHook())
        Feeder = import_class(self.arg.feeder)
        self.runner.register_hook(
            DistEvalTopKAccuracyHook(Feeder(**self.arg.test_feeder_args),
                                     interval=self.arg.test_interval,
                                     k=(1, 5)))

    def eval(self,
             epoch,
             save_score=False,
             loader_name=('test',),
             wrong_file=None,
             result_file=None):
        """Evaluate the model on every loader named in ``loader_name``.

        Args:
            epoch: zero-based epoch index (logged as ``epoch + 1``).
            save_score: if true, pickle ``{sample_name: score}`` per loader to
                ``<work_dir>/epoch{epoch+1}_{loader}_score.pkl``.
            loader_name: keys of ``self.data_loader`` to evaluate.
            wrong_file: optional path receiving ``index,predicted,true`` lines
                for misclassified samples.
            result_file: optional path receiving ``predicted,true`` lines for
                every sample.
        """
        # Skip evaluation if too early
        if epoch + 1 < self.arg.eval_start:
            return

        from contextlib import ExitStack

        # ExitStack closes the optional output files even if evaluation raises
        # (they were previously never closed).
        with ExitStack() as stack, torch.no_grad():
            f_w = (stack.enter_context(open(wrong_file, 'w'))
                   if wrong_file is not None else None)
            f_r = (stack.enter_context(open(result_file, 'w'))
                   if result_file is not None else None)

            self.model = self.model.cuda()
            self.model.eval()
            self.print_log(f'Eval epoch: {epoch + 1}')
            for ln in loader_name:
                loss_values = []
                score_batches = []
                l1 = 0
                process = tqdm(self.data_loader[ln], dynamic_ncols=True)
                for data, label, index in process:
                    data = data.float().cuda()
                    label = label.long().cuda()
                    output = self.model(data)
                    if isinstance(output, tuple):
                        output, l1 = output
                        l1 = l1.mean()
                    else:
                        l1 = 0
                    loss = self.loss(output, label)
                    score_batches.append(output.data.cpu().numpy())
                    loss_values.append(loss.item())

                    _, predict_label = torch.max(output.data, 1)

                    if f_w is not None or f_r is not None:
                        predict = list(predict_label.cpu().numpy())
                        true = list(label.data.cpu().numpy())
                        for i, x in enumerate(predict):
                            if f_r is not None:
                                f_r.write(str(x) + ',' + str(true[i]) + '\n')
                            if x != true[i] and f_w is not None:
                                f_w.write(
                                    str(index[i]) + ',' + str(x) + ',' +
                                    str(true[i]) + '\n')

                # Metrics are computed per loader; previously this ran after
                # the loader loop and only scored the last loader.
                score = np.concatenate(score_batches)
                loss = np.mean(loss_values)
                accuracy = self.data_loader[ln].dataset.top_k(score, 1)
                if accuracy > self.best_acc:
                    self.best_acc = accuracy
                    self.best_acc_epoch = epoch + 1

                print('Accuracy: ', accuracy, ' model: ', self.arg.work_dir)
                if self.arg.phase == 'train' and not self.arg.debug:
                    # NOTE(review): val_writer is not created anywhere in this
                    # class; confirm it is set before eval() runs in training.
                    self.val_writer.add_scalar('loss', loss, self.global_step)
                    self.val_writer.add_scalar('loss_l1', l1, self.global_step)
                    self.val_writer.add_scalar('acc', accuracy, self.global_step)

                score_dict = dict(
                    zip(self.data_loader[ln].dataset.sample_name, score))
                self.print_log(
                    f'\tMean {ln} loss of {len(self.data_loader[ln])} batches: {np.mean(loss_values)}.'
                )
                for k in self.arg.show_topk:
                    self.print_log(
                        f'\tTop {k}: {100 * self.data_loader[ln].dataset.top_k(score, k):.2f}%'
                    )

                if save_score:
                    with open(
                            '{}/epoch{}_{}_score.pkl'.format(
                                self.arg.work_dir, epoch + 1, ln), 'wb') as f:
                        pickle.dump(score_dict, f)

        # Empty cache after evaluation
        torch.cuda.empty_cache()

    def start(self):
        """Run training via the Runner, or a single evaluation in the 'test' phase.

        Raises:
            ValueError: test phase without ``arg.weights``.
        """
        if self.arg.phase == 'train':
            self.runner.run([self.data_loader['train']],
                            workflow=[('train', 1)],
                            max_epochs=self.arg.num_epoch)
        elif self.arg.phase == 'test':
            wf = rf = None
            if self.arg.weights is None:
                raise ValueError('Please appoint --weights.')

            self.print_log(f'Model:   {self.arg.model}')
            self.print_log(f'Weights: {self.arg.weights}')

            self.eval(epoch=0,
                      save_score=self.arg.save_score,
                      loader_name=['test'],
                      wrong_file=wf,
                      result_file=rf)

            self.print_log('Done.\n')
Пример #15
0
def main_worker(train_loader, val_loader, num_classes, args, cifar=False):
    """Build the model, optionally resume, then train and validate each epoch.

    The model class is looked up as ``models.__dict__[args.arch]``; for CIFAR
    the arch is forced to ``resnetD`` with ``args.depth``. A checkpoint is
    saved after every epoch via ``save_checkpoint``, and the module-level
    ``best_acc1`` is updated.

    Args:
        train_loader: training data loader passed to ``train``.
        val_loader: validation data loader passed to ``validate``.
        num_classes: number of output classes for the model.
        args: parsed options; ``args.arch`` and ``args.start_epoch`` may be
            overwritten here.
        cifar: use the small-image stem (3x3 conv, stride 1, padding 1,
            16 planes) instead of (7x7, stride 2, padding 3, 64 planes).
    """
    # NOTE(review): best_acc1 is read before it is assigned in the loop, so it
    # must be initialised at module level unless a checkpoint sets it.
    global best_acc1

    scale_lr_and_momentum(args, cifar=cifar)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    norm_kwargs = {'mode': args.norm_mode,
                   'alpha_fwd': args.afwd,
                   'alpha_bkw': args.abkw,
                   'ecm': args.ecm,
                   'gn_num_groups': args.gn_num_groups}
    model_kwargs = {'num_classes': num_classes,
                    'norm_layer': norm_layer,
                    'norm_kwargs': norm_kwargs,
                    'cifar': cifar,
                    'kernel_size': 3 if cifar else 7,
                    'stride': 1 if cifar else 2,
                    'padding': 1 if cifar else 3,
                    'inplanes': 16 if cifar else 64}
    if cifar:
        model_kwargs['depth'] = args.depth
        args.arch = 'resnetD'

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True,
                                           **model_kwargs).to(device)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch](**model_kwargs).to(device)

    print(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(get_parameter_groups(model, cifar=cifar),
                                args.lr, momentum=args.momentum,
                                weight_decay=args.weight_decay)

    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_milestones,
                            gamma=args.lr_multiplier)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            # map_location lets a GPU-saved checkpoint be resumed on the CPU
            # fallback device selected above.
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # cuDNN autotuning is non-deterministic, so disable it when seeded.
    cudnn.benchmark = not args.seed

    if args.evaluate:
        validate(val_loader, model, criterion, device, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        # The scheduler is stepped once per epoch except the very first one.
        if epoch:
            scheduler.step()

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, device, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, device, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
        }, is_best, args)
Пример #16
0
def train(train_loop_func, logger, args):
    """Train DASSD300 with image-level domain adaptation, or evaluate it once.

    The function proceeds as follows:

    - Sets up optional distributed training and seeds the RNGs.
    - Builds the source/target loaders, model, losses, SGD optimizer and
      MultiStepLR scheduler.
    - Optionally resumes from ``args.checkpoint``.
    - If ``args.mode == 'evaluation'``, evaluates once and returns.
    - Otherwise runs ``train_loop_func`` for each epoch. It evaluates on the
      epochs listed in ``args.evaluation`` and, on rank 0 with ``args.save``
      set, saves checkpoints to ``./models/``.
    """
    # Check that GPUs are actually available
    use_cuda = not args.no_cuda

    # Setup multi-GPU if necessary
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    if args.seed is None:
        args.seed = np.random.randint(1e4)

    if args.distributed:
        args.seed = (args.seed + torch.distributed.get_rank()) % 2**32
    print("Using seed = {}".format(args.seed))
    torch.manual_seed(args.seed)
    np.random.seed(seed=args.seed)

    torch.multiprocessing.set_sharing_strategy('file_system')

    # Setup data, defaults
    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)
    # The third loader argument is an image count. Earlier runs used 82783 or
    # 118287 (commented out below); this run uses 5000.
    # train_loader = get_train_loader(args, args.seed - 2**31, 118287)
    # target_loader = get_target_loader(args, args.seed - 2**31, 118287)
    train_loader = get_train_loader(args, args.seed - 2**31, 5000)

    target_loader = get_target_loader(args, args.seed - 2**31, 5000)

    val_dataset = get_val_dataset(args)
    val_dataloader = get_val_dataloader(val_dataset, args)

    ssd300 = DASSD300(backbone=ResNet(args.backbone, args.backbone_path))
    # Linear LR scaling by GPU count and per-step batch size relative to 32.
    # The batch size is counted twice, presumably because each step consumes
    # one source batch and one target batch (TODO confirm).
    # Previously considered alternative (its original note was garbled):
    # args.learning_rate = args.learning_rate * args.N_gpu * ((args.batch_size + args.batch_size // 2) / 32)
    args.learning_rate = args.learning_rate * args.N_gpu * (
        (args.batch_size + args.batch_size) / 32)
    start_epoch = 0
    iteration = 0
    loss_func = DALoss(dboxes)
    da_loss_func = ImageLevelAdaptationLoss()

    if use_cuda:
        ssd300.cuda()
        loss_func.cuda()
        da_loss_func.cuda()

    optimizer = torch.optim.SGD(tencent_trick(ssd300),
                                lr=args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=args.multistep,
                            gamma=0.1)
    if args.amp:
        ssd300, optimizer = amp.initialize(ssd300, optimizer, opt_level='O2')

    if args.distributed:
        ssd300 = DDP(ssd300)

    if args.checkpoint is not None:
        if os.path.isfile(args.checkpoint):
            load_checkpoint(ssd300.module if args.distributed else ssd300,
                            args.checkpoint)
            checkpoint = torch.load(args.checkpoint,
                                    map_location=lambda storage, loc: storage.
                                    cuda(torch.cuda.current_device()))
            start_epoch = checkpoint['epoch']
            iteration = checkpoint['iteration']
            scheduler.load_state_dict(checkpoint['scheduler'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            print('Provided checkpoint is not path to a file')
            return

    inv_map = {v: k for k, v in val_dataset.label_map.items()}

    total_time = 0

    if args.mode == 'evaluation':
        acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
        if args.local_rank == 0:
            print('Model precision {} mAP'.format(acc))

        return
    mean, std = generate_mean_std(args)

    meters = {
        'total': AverageValueMeter(),
        'ssd': AverageValueMeter(),
        'da': AverageValueMeter()
    }

    vis = Visualizer(env='da ssd', port=6006)

    for epoch in range(start_epoch, args.epochs):
        start_epoch_time = time.time()
        # NOTE(review): the scheduler is stepped *before* each epoch's
        # training, so a milestone m takes effect at epoch index m - 1.
        # Kept as-is to stay consistent with existing checkpoints.
        scheduler.step()
        iteration = train_loop_func(ssd300, loss_func, da_loss_func, epoch,
                                    optimizer, train_loader, target_loader,
                                    encoder, iteration, logger, args, mean,
                                    std, meters, vis)
        end_epoch_time = time.time() - start_epoch_time
        total_time += end_epoch_time

        if args.local_rank == 0:
            logger.update_epoch_time(epoch, end_epoch_time)

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map,
                           args)

            if args.local_rank == 0:
                logger.update_epoch(epoch, acc)
                vis.log(acc, win='Evaluation')

        if args.save and args.local_rank == 0:
            print("saving model...")
            obj = {
                'epoch': epoch + 1,
                'iteration': iteration,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'label_map': val_dataset.label_info
            }
            if args.distributed:
                obj['model'] = ssd300.module.state_dict()
            else:
                obj['model'] = ssd300.state_dict()
            # torch.save does not create missing directories.
            os.makedirs('./models', exist_ok=True)
            torch.save(obj, './models/epoch_{}.pt'.format(epoch))
        train_loader.reset()
        target_loader.reset()

    print('total training time: {}'.format(total_time))
Пример #17
0
def main(opt):
    """Train an SSD detector on OIDataset, evaluating and checkpointing every epoch.

    The function proceeds as follows:

    - Builds the train/validation loaders, the model, the loss, and an SGD
      optimizer (LR scaled linearly by ``batch_size / 32``) with MultiStepLR.
    - Wipes and recreates ``opt.log_path`` for TensorBoard.
    - Resumes from ``<save_folder>/SSD.pth`` if that file exists.
    - Trains, evaluates and overwrites that checkpoint once per epoch.
    """
    # torch.manual_seed seeds the CPU RNG *and* all CUDA devices. Seeding only
    # CUDA left model init and DataLoader shuffling unseeded.
    torch.manual_seed(123)

    train_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    eval_params = {
        "batch_size": opt.batch_size,
        "shuffle": True,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn
    }

    dboxes = generate_dboxes()
    model = SSD()
    train_set = OIDataset(SimpleTransformer(dboxes), train=True)
    train_loader = DataLoader(train_set, **train_params)
    val_set = OIDataset(SimpleTransformer(dboxes, eval=True), validation=True)
    val_loader = DataLoader(val_set, **eval_params)

    encoder = Encoder(dboxes)

    opt.lr = opt.lr * (opt.batch_size / 32)
    criterion = Loss(dboxes)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=opt.multistep,
                            gamma=0.1)

    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

    model = torch.nn.DataParallel(model)

    # NOTE(review): previous TensorBoard logs are deleted on every run, even
    # when resuming from a checkpoint.
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    os.makedirs(opt.save_folder, exist_ok=True)
    checkpoint_path = os.path.join(opt.save_folder, "SSD.pth")

    writer = SummaryWriter(opt.log_path)

    if os.path.isfile(checkpoint_path):
        # Load onto the CPU first. load_state_dict copies into the model's
        # device and the optimizer casts its state to each parameter's device,
        # so GPU-saved checkpoints also resume on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        first_epoch = checkpoint["epoch"] + 1
        model.module.load_state_dict(checkpoint["model_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        first_epoch = 0

    for epoch in range(first_epoch, opt.epochs):
        train(model, train_loader, epoch, writer, criterion, optimizer,
              scheduler)
        evaluate(model, val_loader, encoder, opt.nms_threshold)

        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.module.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)
def run_resnet(args):
    """Train a ``Resnet`` on CIFAR-100, checkpoint periodically, then report test accuracy.

    Args:
        args: namespace whose ``b`` attribute is the mini-batch size used for
            both the train and the test loader.

    Side effects:
        Downloads CIFAR-100 into ``./data/`` if missing, writes a checkpoint to
        ``checkpoints/resnet/`` every 100 epochs (epochs 0, 100, ..., 600),
        saves the final weights to ``models/lenet.pth`` and prints the
        per-epoch training loss and the final test accuracy.
    """
    # Per-channel mean/std pre-computed on the CIFAR-100 *training* split.
    # Both splits are normalised with these statistics so that train and test
    # inputs go through the same transformation.
    normalize = transforms.Normalize((0.5071, 0.4865, 0.4409),
                                     (0.2673, 0.2564, 0.2762))

    # The Resnet paper states the following transforms are applied on the
    # train set: 4-pixel padding, random horizontal flip, random 32x32 crop.
    # Normalize only accepts tensors, so it must run *after* ToTensor
    # (applying it to a PIL image raises a TypeError on the first batch).
    train_set = datasets.CIFAR100(
        "./data/",
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.Pad(4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32),
            transforms.ToTensor(),
            normalize,
        ]))
    test_set = datasets.CIFAR100(
        "./data/",
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.b,
                                               shuffle=True)
    # Accuracy does not depend on evaluation order, so no shuffling here.
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=args.b,
                                              shuffle=False)

    checkpoints_dir = "checkpoints/resnet/"
    final_dir = 'models/'
    # torch.save does not create missing parent directories.
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(final_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = Resnet(3).to(device)
    optimizer = optim.SGD(model.parameters(),
                          lr=0.1,
                          weight_decay=0.0001,
                          momentum=0.9)
    loss = nn.CrossEntropyLoss()

    # The paper decays the learning rate at 32k and 64k *iterations*; this
    # script approximates that with decays at epochs 320 and 480 of 640.
    scheduler = MultiStepLR(optimizer, milestones=[320, 480], gamma=0.1)

    for epoch in range(0, 640):
        loss_train = 0.0

        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)

            predictions = model(images)
            batch_loss = loss(predictions, labels)

            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            loss_train += batch_loss.item()

        print('{} Epoch {}, Training loss {}'.format(
            datetime.datetime.now(), epoch + 1,
            loss_train / len(train_loader)))
        # One scheduler step per epoch, after that epoch's optimizer updates.
        scheduler.step()

        if epoch % 100 == 0:
            checkpoint_path = os.path.join(checkpoints_dir,
                                           'epoch_' + str(epoch) + '.pt')
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict()
                }, checkpoint_path)

    # NOTE: the file name says "lenet" although the model is a Resnet; it is
    # kept unchanged so existing code that loads this path keeps working.
    model_path = os.path.join(final_dir, 'lenet.pth')
    torch.save(model.state_dict(), model_path)

    model.eval()
    with torch.no_grad():

        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            # Predicted class = index of the largest logit.
            _, predicted = torch.max(outputs, dim=1)
            total += labels.shape[0]
            correct += (predicted == labels).sum().item()

        print("Accuracy = {}".format(100 * (correct / total)))
Пример #19
0
class Learner:
    """Episodic (few-shot) trainer/evaluator for the ``CNN_TRX`` video model.

    Construction parses the command line, resolves the checkpoint/log files,
    builds the model, the task loader, the optimizer and a ``MultiStepLR``
    schedule, and optionally restores state from ``checkpoint.pt``.
    """

    def __init__(self):
        self.args = self.parse_command_line()

        self.checkpoint_dir, self.logfile, self.checkpoint_path_validation, self.checkpoint_path_final \
            = get_log_files(self.args.checkpoint_dir, self.args.resume_from_checkpoint, False)

        print_and_log(self.logfile, "Options: %s\n" % self.args)
        print_and_log(self.logfile,
                      "Checkpoint Directory: %s\n" % self.checkpoint_dir)

        self.writer = SummaryWriter()

        gpu_device = 'cuda'
        self.device = torch.device(
            gpu_device if torch.cuda.is_available() else 'cpu')
        self.model = self.init_model()
        self.train_set, self.validation_set, self.test_set = self.init_data()

        # batch_size=1: every loader item is a single task dict, which
        # prepare_task() unpacks (it indexes [0] on every field).
        self.vd = video_reader.VideoDataset(self.args)
        self.video_loader = torch.utils.data.DataLoader(
            self.vd, batch_size=1, num_workers=self.args.num_workers)

        self.loss = loss
        self.accuracy_fn = aggregate_accuracy

        # --opt is restricted to exactly these two choices by the parser.
        if self.args.opt == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=self.args.learning_rate)
        elif self.args.opt == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(),
                                             lr=self.args.learning_rate)
        self.test_accuracies = TestAccuracies(self.test_set)

        # run() steps the scheduler once per task, so the --sch milestones
        # are counted in training iterations, not epochs.
        self.scheduler = MultiStepLR(self.optimizer,
                                     milestones=self.args.sch,
                                     gamma=0.1)

        self.start_iteration = 0
        if self.args.resume_from_checkpoint:
            self.load_checkpoint()
        self.optimizer.zero_grad()

    def init_model(self):
        """Build ``CNN_TRX``, move it to ``self.device`` and, when more than
        one GPU is requested, let the model distribute itself across them."""
        model = CNN_TRX(self.args)
        model = model.to(self.device)
        if self.args.num_gpus > 1:
            model.distribute_model()
        return model

    def init_data(self):
        """Return ``(train, validation, test)``, each a one-element list
        holding the configured dataset name."""
        train_set = [self.args.dataset]
        validation_set = [self.args.dataset]
        test_set = [self.args.dataset]

        return train_set, validation_set, test_set

    def parse_command_line(self):
        """Parse ``sys.argv`` and derive the dependent settings.

        Derived values:
            * ``scratch``: "bc"/"bp" are expanded to absolute storage roots;
              "bp" (the default) also forces ``num_gpus=4`` and
              ``num_workers=5``, overriding any values given on the CLI.
            * ``img_size`` is forced to 224 for resnet34/resnet50.
            * ``trans_linear_in_dim`` is 2048 for resnet50, else 512.
            * ``traintestlist`` and ``path`` point at the dataset split list
              and archive under ``scratch``.

        Exits with status 1 when ``--checkpoint_dir`` is not given.
        """
        parser = argparse.ArgumentParser()

        parser.add_argument("--dataset",
                            choices=["ssv2", "kinetics"],
                            default="ssv2",
                            help="Dataset to use.")
        parser.add_argument("--learning_rate",
                            "-lr",
                            type=float,
                            default=0.001,
                            help="Learning rate.")
        parser.add_argument(
            "--tasks_per_batch",
            type=int,
            default=16,
            help="Number of tasks between parameter optimizations.")
        parser.add_argument("--checkpoint_dir",
                            "-c",
                            default=None,
                            help="Directory to save checkpoint to.")
        parser.add_argument("--test_model_path",
                            "-m",
                            default=None,
                            help="Path to model to load and test.")
        parser.add_argument("--training_iterations",
                            "-i",
                            type=int,
                            default=50020,
                            help="Number of meta-training iterations.")
        parser.add_argument("--resume_from_checkpoint",
                            "-r",
                            dest="resume_from_checkpoint",
                            default=False,
                            action="store_true",
                            help="Restart from latest checkpoint.")
        parser.add_argument("--way",
                            type=int,
                            default=5,
                            help="Way of single dataset task.")
        parser.add_argument(
            "--shot",
            type=int,
            default=1,
            help="Shots per class for context of single dataset task.")
        parser.add_argument("--query_per_class",
                            type=int,
                            default=5,
                            help="Target samples (i.e. queries) per class.")

        parser.add_argument("--seq_len",
                            type=int,
                            default=8,
                            help="Frames per video.")
        parser.add_argument("--num_workers",
                            type=int,
                            default=10,
                            help="Num dataloader workers.")
        parser.add_argument("--method",
                            choices=["resnet18", "resnet34", "resnet50"],
                            default="resnet50",
                            help="method")
        parser.add_argument("--trans_linear_out_dim",
                            type=int,
                            default=1152,
                            help="Transformer linear_out_dim")
        parser.add_argument("--opt",
                            choices=["adam", "sgd"],
                            default="sgd",
                            help="Optimizer")
        # A dropout probability is fractional; type=int rejected values such
        # as "--trans_dropout 0.2" even though the default is 0.1.
        parser.add_argument("--trans_dropout",
                            type=float,
                            default=0.1,
                            help="Transformer dropout")
        parser.add_argument(
            "--save_freq",
            type=int,
            default=5000,
            help="Number of iterations between checkpoint saves.")
        parser.add_argument("--img_size",
                            type=int,
                            default=224,
                            help="Input image size to the CNN after cropping.")
        parser.add_argument('--temp_set',
                            nargs='+',
                            type=int,
                            help='cardinalities e.g. 2,3 is pairs and triples',
                            default=[2, 3])

        parser.add_argument("--scratch",
                            choices=["bc", "bp"],
                            default="bp",
                            help="Computer to run on")
        parser.add_argument("--num_gpus",
                            type=int,
                            default=1,
                            help="Number of GPUs to split the ResNet over")
        parser.add_argument("--debug_loader",
                            default=False,
                            action="store_true",
                            help="Load 1 vid per class for debugging")

        parser.add_argument("--split",
                            type=int,
                            default=3,
                            help="Dataset split.")
        parser.add_argument('--sch',
                            nargs='+',
                            type=int,
                            help='iters to drop learning rate',
                            default=[1000000])

        args = parser.parse_args()

        if args.scratch == "bc":
            args.scratch = "/mnt/storage/home/tp8961/scratch"
        elif args.scratch == "bp":
            # Machine-specific overrides: these replace any CLI values.
            args.num_gpus = 4
            args.num_workers = 5
            args.scratch = "/work/tp8961"

        if args.checkpoint_dir is None:
            print("need to specify a checkpoint dir")
            # Same as sys.exit(1), without relying on the site-provided exit().
            raise SystemExit(1)

        if (args.method == "resnet50") or (args.method == "resnet34"):
            args.img_size = 224
        if args.method == "resnet50":
            args.trans_linear_in_dim = 2048
        else:
            args.trans_linear_in_dim = 512

        if args.dataset == "ssv2":
            args.traintestlist = os.path.join(
                args.scratch,
                "video_datasets/splits/somethingsomethingv2TrainTestlist")
            args.path = os.path.join(
                args.scratch,
                "video_datasets/data/somethingsomethingv2_256x256q5_1.zip")
        elif args.dataset == "kinetics":
            args.traintestlist = os.path.join(
                args.scratch, "video_datasets/splits/kineticsTrainTestlist")
            args.path = os.path.join(
                args.scratch, "video_datasets/data/kinetics_256q5_1.zip")
        return args

    def run(self):
        """Meta-train for ``training_iterations`` tasks.

        Gradients from ``tasks_per_batch`` consecutive tasks are accumulated
        (train_task scales each loss by ``1 / tasks_per_batch``) before one
        optimizer step. Stats are logged every ``PRINT_FREQUENCY`` tasks,
        checkpoints are written every ``save_freq`` tasks, the model is tested
        at the iterations in ``TEST_ITERS``, and the final weights are saved to
        ``self.checkpoint_path_final``.
        """
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = True
        with tf.compat.v1.Session(config=config) as session:
            train_accuracies = []
            losses = []
            total_iterations = self.args.training_iterations

            iteration = self.start_iteration
            for task_dict in self.video_loader:
                if iteration >= total_iterations:
                    break
                iteration += 1
                torch.set_grad_enabled(True)

                task_loss, task_accuracy = self.train_task(task_dict)
                train_accuracies.append(task_accuracy)
                # Detach so the stored loss no longer references autograd state.
                losses.append(task_loss.detach())

                # optimize
                # NOTE(review): `iteration` is already incremented here, so the
                # `+ 1` makes the first update happen after tasks_per_batch - 1
                # tasks; kept as-is to preserve the existing training schedule.
                if ((iteration + 1) % self.args.tasks_per_batch
                        == 0) or (iteration == (total_iterations - 1)):
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                self.scheduler.step()
                if (iteration + 1) % PRINT_FREQUENCY == 0:
                    # print training stats
                    print_and_log(
                        self.logfile,
                        'Task [{}/{}], Train Loss: {:.7f}, Train Accuracy: {:.7f}'
                        .format(iteration + 1, total_iterations,
                                torch.Tensor(losses).mean().item(),
                                torch.Tensor(train_accuracies).mean().item()))
                    train_accuracies = []
                    losses = []

                if ((iteration + 1) % self.args.save_freq
                        == 0) and (iteration + 1) != total_iterations:
                    self.save_checkpoint(iteration + 1)

                if ((iteration + 1)
                        in TEST_ITERS) and (iteration + 1) != total_iterations:
                    accuracy_dict = self.test(session)
                    self.test_accuracies.print(self.logfile, accuracy_dict)

            # save the final model
            torch.save(self.model.state_dict(), self.checkpoint_path_final)

        self.logfile.close()

    def train_task(self, task_dict):
        """Forward one task, backpropagate its scaled loss and return
        ``(task_loss, task_accuracy)``; no optimizer step is taken here."""
        context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list = self.prepare_task(
            task_dict)

        model_dict = self.model(context_images, context_labels, target_images)
        target_logits = model_dict['logits']

        # Scaled so the gradients accumulated over tasks_per_batch tasks
        # average rather than sum.
        task_loss = self.loss(target_logits, target_labels,
                              self.device) / self.args.tasks_per_batch
        task_accuracy = self.accuracy_fn(target_logits, target_labels)

        task_loss.backward(retain_graph=False)

        return task_loss, task_accuracy

    def test(self, session):
        """Evaluate the model on up to ``NUM_TEST_TASKS`` tasks.

        Args:
            session: accepted for interface compatibility; not used here.

        Returns:
            ``{self.args.dataset: {"accuracy": ..., "confidence": ...}}`` where
            accuracy is the mean per-task accuracy in percent and confidence is
            ``1.96 * std / sqrt(n)``, also in percent (the usual
            normal-approximation 95% half-width).
        """
        self.model.eval()
        self.video_loader.dataset.train = False
        try:
            with torch.no_grad():
                accuracy_dict = {}
                accuracies = []
                iteration = 0
                item = self.args.dataset
                for task_dict in self.video_loader:
                    if iteration >= NUM_TEST_TASKS:
                        break
                    iteration += 1

                    context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list = self.prepare_task(
                        task_dict)
                    model_dict = self.model(context_images, context_labels,
                                            target_images)
                    target_logits = model_dict['logits']
                    accuracy = self.accuracy_fn(target_logits, target_labels)
                    accuracies.append(accuracy.item())
                    del target_logits

                accuracy = np.array(accuracies).mean() * 100.0
                confidence = (196.0 * np.array(accuracies).std()) / np.sqrt(
                    len(accuracies))

                accuracy_dict[item] = {
                    "accuracy": accuracy,
                    "confidence": confidence
                }
        finally:
            # Switch the loader and model back to training mode even if
            # evaluation raised, so training never resumes on eval settings.
            self.video_loader.dataset.train = True
            self.model.train()

        return accuracy_dict

    def prepare_task(self, task_dict, images_to_device=True):
        """Unpack a loader item (batch of one task) into its components.

        Returns ``(context_images, target_images, context_labels,
        target_labels, real_target_labels, batch_class_list)``; labels are
        moved to ``self.device`` (target labels as int64), images only when
        ``images_to_device`` is true.
        """
        context_images, context_labels = task_dict['support_set'][
            0], task_dict['support_labels'][0]
        target_images, target_labels = task_dict['target_set'][0], task_dict[
            'target_labels'][0]
        real_target_labels = task_dict['real_target_labels'][0]
        batch_class_list = task_dict['batch_class_list'][0]

        if images_to_device:
            context_images = context_images.to(self.device)
            target_images = target_images.to(self.device)
        context_labels = context_labels.to(self.device)
        target_labels = target_labels.type(torch.LongTensor).to(self.device)

        return context_images, target_images, context_labels, target_labels, real_target_labels, batch_class_list

    def shuffle(self, images, labels):
        """
        Return shuffled data.
        """
        permutation = np.random.permutation(images.shape[0])
        return images[permutation], labels[permutation]

    def save_checkpoint(self, iteration):
        """Save model/optimizer/scheduler state both as
        ``checkpoint{iteration}.pt`` and as the rolling ``checkpoint.pt``."""
        d = {
            'iteration': iteration,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict()
        }

        torch.save(
            d,
            os.path.join(self.checkpoint_dir,
                         'checkpoint{}.pt'.format(iteration)))
        torch.save(d, os.path.join(self.checkpoint_dir, 'checkpoint.pt'))

    def load_checkpoint(self):
        """Restore state from ``checkpoint.pt`` in ``self.checkpoint_dir``.

        Tensors are mapped onto ``self.device`` so a checkpoint written on a
        GPU machine can also be loaded on a CPU-only host.
        """
        checkpoint = torch.load(
            os.path.join(self.checkpoint_dir, 'checkpoint.pt'),
            map_location=self.device)
        self.start_iteration = checkpoint['iteration']
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
def main(args: argparse.Namespace):
    """Train (or, with ``args.phase == 'test'``, only evaluate) a keypoint model.

    Builds source/target train and test loaders, the model, a ``JointsMSELoss``
    criterion, Adam and a ``MultiStepLR`` schedule. It optionally resumes from
    ``args.resume``, then trains for ``args.epochs`` epochs, validating on both
    test splits after every epoch.

    Side effects:
        Saves a checkpoint per epoch via ``CompleteLogger`` and copies the one
        with the best target accuracy to the ``'best'`` checkpoint path.
    """
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # cuDNN benchmark mode times candidate kernels and may pick different
    # ones between runs, which defeats cudnn.deterministic; only enable it
    # for unseeded runs.
    cudnn.benchmark = args.seed is None

    # Data loading code
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.RandomRotation(args.rotation),
        T.RandomResizedCrop(size=args.image_size, scale=args.resize_scale),
        T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25),
        T.GaussianBlur(),
        T.ToTensor(), normalize
    ])
    val_transform = T.Compose(
        [T.Resize(args.image_size),
         T.ToTensor(), normalize])
    image_size = (args.image_size, args.image_size)
    heatmap_size = (args.heatmap_size, args.heatmap_size)
    source_dataset = datasets.__dict__[args.source]
    train_source_dataset = source_dataset(root=args.source_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_source_dataset = source_dataset(root=args.source_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_source_loader = DataLoader(val_source_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    target_dataset = datasets.__dict__[args.target]
    train_target_dataset = target_dataset(root=args.target_root,
                                          transforms=train_transform,
                                          image_size=image_size,
                                          heatmap_size=heatmap_size)
    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     drop_last=True)
    val_target_dataset = target_dataset(root=args.target_root,
                                        split='test',
                                        transforms=val_transform,
                                        image_size=image_size,
                                        heatmap_size=heatmap_size)
    val_target_loader = DataLoader(val_target_dataset,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   pin_memory=True)

    print("Source train:", len(train_source_loader))
    print("Target train:", len(train_target_loader))
    print("Source test:", len(val_source_loader))
    print("Target test:", len(val_target_loader))

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    model = models.__dict__[args.arch](
        num_keypoints=train_source_dataset.num_keypoints).to(device)
    criterion = JointsMSELoss()

    # define optimizer and lr scheduler
    optimizer = Adam(model.get_parameters(lr=args.lr))
    lr_scheduler = MultiStepLR(optimizer, args.lr_step, args.lr_factor)

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1

    # define visualization function
    tensor_to_image = Compose([
        Denormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ToPILImage()
    ])

    def visualize(image, keypoint2d, name):
        """
        Args:
            image (tensor): image in shape 3 x H x W
            keypoint2d (tensor): keypoints in shape K x 2
            name: name of the saving image
        """
        train_source_dataset.visualize(
            tensor_to_image(image), keypoint2d,
            logger.get_image_path("{}.jpg".format(name)))

    if args.phase == 'test':
        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize, args)
        print("Source: {:4.3f} Target: {:4.3f}".format(source_val_acc['all'],
                                                       target_val_acc['all']))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))
        return

    # start training
    best_acc = 0
    for epoch in range(start_epoch, args.epochs):
        logger.set_epoch(epoch)

        # train for one epoch
        train(train_source_iter, train_target_iter, model, criterion,
              optimizer, epoch, visualize if args.debug else None, args)
        # Since PyTorch 1.1 the scheduler must be stepped *after* the
        # optimizer; stepping at the start of the epoch skipped the initial
        # learning rate and applied every milestone one epoch early.
        lr_scheduler.step()

        # evaluate on validation set
        source_val_acc = validate(val_source_loader, model, criterion, None,
                                  args)
        target_val_acc = validate(val_target_loader, model, criterion,
                                  visualize if args.debug else None, args)

        # remember best acc and save checkpoint
        torch.save(
            {
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lr_scheduler': lr_scheduler.state_dict(),
                'epoch': epoch,
                'args': args
            }, logger.get_checkpoint_path(epoch))
        if target_val_acc['all'] > best_acc:
            shutil.copy(logger.get_checkpoint_path(epoch),
                        logger.get_checkpoint_path('best'))
            best_acc = target_val_acc['all']
        print("Source: {:4.3f} Target: {:4.3f} Target(best): {:4.3f}".format(
            source_val_acc['all'], target_val_acc['all'], best_acc))
        for name, acc in target_val_acc.items():
            print("{}: {:4.3f}".format(name, acc))

    logger.close()
Пример #21
0
def main():
    """Train (or, with ``args.test_only``, only evaluate) a classifier whose
    per-layer bit widths (``k_bits``) are configured from the module-level
    ``args``.

    Creates a timestamped job directory, logger and TensorBoard writers, builds
    the DALI data iterators and the model, optionally loads weights from
    ``args.resume``, then trains for ``args.train_epochs`` epochs, saving a
    checkpoint (and tracking the best test Prec@1) after each epoch.

    Raises:
        ValueError: if ``args.dataset`` has no data loader or no dataset sizes
            configured below.
    """
    start_epoch = 0
    best_prec1 = 0.0

    # A fresh random seed per run; it is logged below so a run can be reproduced.
    seed = np.random.randint(10000)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if args.gpus is not None:
        device = torch.device("cuda:{}".format(args.gpus[0]))
        cudnn.benchmark = False
        cudnn.enabled = True
    else:
        device = torch.device("cpu")

    now = datetime.now().strftime('%Y-%m-%d-%H:%M:%S')

    # One name for the architecture variant, shared by the job directory and
    # the model-factory lookup below (previously resnet20 got a different
    # directory name depending on whether a mission was set).
    if 'vgg' == args.arch and args.batchnorm:
        arch_tag = f'{args.arch}{args.num_layers}_bn'
    elif 'resnet20' == args.arch:
        arch_tag = args.arch
    else:
        arch_tag = f'{args.arch}{args.num_layers}'

    job_dir = f'{args.job_dir}/{args.dataset}/{arch_tag}'
    if args.mission is not None:
        job_dir = f'{job_dir}/{args.mission}'
    args.job_dir = f'{job_dir}/{now}'

    _make_dir(args.job_dir)
    ckpt = utils.checkpoint(args)
    print_logger = utils.get_logger(os.path.join(args.job_dir, "logger.log"))
    utils.print_params(vars(args), print_logger.info)
    print_logger.info(f'seed {seed}')
    writer_train = SummaryWriter(args.job_dir + '/run/train')
    writer_test = SummaryWriter(args.job_dir + '/run/test')

    ## hyperparameters settings ##
    n_layers = (args.num_layers - 2) * 2
    unit_k_bits = int(args.k_bits)
    kbits_list = [unit_k_bits] * n_layers
    print_logger.info(f'k_bits_list {kbits_list}')

    # Data loading
    print('=> Preparing data..')

    if args.dataset in ['cifar10', 'cifar100', 'mnist']:
        IMAGE_SIZE = 32
    elif args.dataset == 'tinyimagenet':
        IMAGE_SIZE = 64
    else:
        IMAGE_SIZE = 224

    if args.dataset in ('imagenet', 'tinyimagenet'):
        train_loader = get_imagenet_iter_dali(type='train', image_dir=args.data_dir, batch_size=args.train_batch_size, num_threads=args.workers, crop=IMAGE_SIZE, device_id=0, num_gpus=1)
        val_loader = get_imagenet_iter_dali(type='val', image_dir=args.data_dir, batch_size=args.eval_batch_size, num_threads=args.workers, crop=IMAGE_SIZE, device_id=0, num_gpus=1)
    elif args.dataset == 'cifar10':
        train_loader = get_cifar_iter_dali(type='train', image_dir=args.data_dir, batch_size=args.train_batch_size, num_threads=args.workers)
        val_loader = get_cifar_iter_dali(type='val', image_dir=args.data_dir, batch_size=args.eval_batch_size, num_threads=args.workers)
    else:
        # Fail here with a clear message instead of a NameError later on.
        raise ValueError(f'no data loader is configured for dataset {args.dataset!r}')

    # Create model
    print('=> Building model...')
    # Dataset sizes handed to train()/test().
    # NOTE(review): the imagenet values equal the CIFAR-10 ones and look
    # copy-pasted; confirm against the real split sizes.
    if args.dataset == 'cifar10':
        train_data_length = 50000
        eval_data_length = 10000
    elif args.dataset == 'imagenet':
        train_data_length = 50000
        eval_data_length = 10000
    else:
        raise ValueError(f'no dataset sizes are configured for dataset {args.dataset!r}')

    model_config = {'k_bits': kbits_list, 'num_layers': args.num_layers,
                    'pre_k_bits': args.pre_k_bits, 'ratio': args.ratio}
    if args.arch == 'mobilenetv2':
        model_config['width_mult'] = args.width_mult
    arch_module = import_module(f"models.{args.dataset}.{args.archtype}.{args.arch}")
    model, model_k_bits = arch_module.__dict__[arch_tag](model_config)

    model = model.to(device)
    print_logger.info(f'model_k_bits_list {model_k_bits}')
    # Define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = MultiStepLR(optimizer, milestones=[0.5 * args.train_epochs, 0.75 * args.train_epochs], gamma=0.1)

    # Optionally resume from a checkpoint.
    # NOTE(review): only the weights are restored (through load_check); the
    # training loop still starts at epoch 0, so this appears to act as a
    # weight initialisation. start_epoch is only used by the test-only path.
    resume = args.resume
    if resume:
        print('=> Loading checkpoint {}'.format(resume))
        checkpoint = torch.load(resume, map_location=device)
        state_dict = checkpoint['state_dict']
        start_epoch = checkpoint['epoch']
        pre_train_best_prec1 = checkpoint['best_prec1']
        model_check = load_check(state_dict, model)
        model.load_state_dict(model_check)
        print('Prec@1:', pre_train_best_prec1)

    if args.test_only:
        # Same argument list as the per-epoch test() call below.
        test_prec1 = test(args, device, val_loader, eval_data_length, model, criterion, writer_test, print_logger, start_epoch)
        print('=> Test Prec@1: {:.2f}'.format(test_prec1))
        print(f'sample k_bits {kbits_list}')
        return

    for epoch in range(0, args.train_epochs):
        # Sets the learning rate for `epoch` explicitly (deprecated
        # step(epoch) form, kept to preserve the existing schedule).
        scheduler.step(epoch)
        train_loss, train_prec1 = train(args, device, train_loader, train_data_length, model, criterion, optimizer, writer_train, print_logger, epoch)
        test_prec1 = test(args, device, val_loader, eval_data_length, model, criterion, writer_test, print_logger, epoch)

        is_best = best_prec1 < test_prec1
        best_prec1 = max(test_prec1, best_prec1)

        state = {
                'state_dict': model.state_dict(),
                'test_prec1': test_prec1,
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch + 1
            }
        ckpt.save_model(state, epoch + 1, is_best, mode='train')
        # float() accepts both a Python number and a 0-d tensor, whereas
        # .item() failed whenever best_prec1 was still the float 0.0.
        print_logger.info('==> BEST ACC {:.3f}'.format(float(best_prec1)))
def main():
    """Train and evaluate an AttenSiResNet classifier over ``train_dir.speakers``.

    Configuration and data come from module-level globals defined elsewhere in
    the file (``args``, ``train_dir``, ``valid_dir``, ``test_part``,
    ``kwargs``, ``opt_kwargs``, ``writer``); per-epoch work is delegated to the
    file's ``train`` and ``test`` functions.

    If ``args.resume`` names an existing file, model, optimizer and scheduler
    state are restored from it and training continues from the stored epoch
    (plus ``args.start_epoch``). An initial checkpoint tagged ``-1`` is always
    written to ``args.check_path`` before training starts.
    """
    # Print the experiment configuration.
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # Instantiate the model.
    model = AttenSiResNet(layers=[3, 4, 6, 3], num_classes=len(train_dir.speakers))
    if args.cuda:
        model.cuda()

    optimizer = create_optimizer(model.parameters(), args.optimizer, **opt_kwargs)
    scheduler = MultiStepLR(optimizer, milestones=[18, 24], gamma=0.1)

    start = 0
    # Optionally resume model / optimizer / scheduler state from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            # Read the checkpoint file once and reuse it for every component.
            checkpoint = torch.load(args.resume)
            start = checkpoint['epoch']
            # Strip BatchNorm 'num_batches_tracked' counters before loading.
            filtered = {k: v for k, v in checkpoint['state_dict'].items()
                        if 'num_batches_tracked' not in k}
            model.load_state_dict(filtered)
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = start + args.epochs

    train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                               shuffle=True, **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=args.batch_size,
                                               shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part, batch_size=args.test_batch_size, shuffle=False, **kwargs)

    criterion = nn.CrossEntropyLoss()
    # Keep the loss on the same device policy as the model.
    if args.cuda:
        criterion = criterion.cuda()

    # Snapshot the initial (or freshly resumed) state as checkpoint "-1".
    check_path = '{}/checkpoint_{}.pth'.format(args.check_path, -1)
    torch.save({'epoch': -1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()},
               check_path)

    for epoch in range(start, end):
        for param_group in optimizer.param_groups:
            print('\n\33[1;34m Current \'{}\' learning rate is {}.\33[0m'.format(args.optimizer, param_group['lr']))

        train(train_loader, model, optimizer, criterion, scheduler, epoch)
        test(test_loader, valid_loader, model, epoch)

        # The scheduler is advanced once at the end of every epoch here
        # (note: ``train`` also receives the scheduler).
        scheduler.step()

    writer.close()
Пример #23
0
def main(opt):
    """Train an SSD or SSDLite detector on COCO 2017.

    When CUDA is available a NCCL process group is initialised via the
    ``env://`` rendezvous and the model is wrapped in DistributedDataParallel
    (apex's when ``opt.amp`` is set). The run resumes from
    ``<opt.save_folder>/SSD.pth`` if it exists, then trains, evaluates and
    overwrites that checkpoint once per epoch.

    Args:
        opt: parsed options; reads ``model``, ``batch_size``, ``num_workers``,
            ``data_path``, ``lr``, ``momentum``, ``weight_decay``,
            ``multistep``, ``amp``, ``log_path``, ``save_folder``, ``epochs``
            and ``nms_threshold``. ``opt.lr`` is rescaled in place.
    """
    if torch.cuda.is_available():
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        num_gpus = torch.distributed.get_world_size()
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
        num_gpus = 1

    loader_params = {
        "batch_size": opt.batch_size * num_gpus,
        "drop_last": False,
        "num_workers": opt.num_workers,
        "collate_fn": collate_fn,
    }
    train_params = {**loader_params, "shuffle": True}
    test_params = {**loader_params, "shuffle": False}

    if opt.model == "ssd":
        dboxes = generate_dboxes(model="ssd")
        model = SSD(backbone=ResNet(), num_classes=len(coco_classes))
    else:
        dboxes = generate_dboxes(model="ssdlite")
        model = SSDLite(backbone=MobileNetV2(), num_classes=len(coco_classes))
    train_set = CocoDataset(opt.data_path, 2017, "train",
                            SSDTransformer(dboxes, (300, 300), val=False))
    train_loader = DataLoader(train_set, **train_params)
    test_set = CocoDataset(opt.data_path, 2017, "val",
                           SSDTransformer(dboxes, (300, 300), val=True))
    test_loader = DataLoader(test_set, **test_params)

    encoder = Encoder(dboxes)

    # Scale the LR linearly with the global batch size, relative to 32.
    opt.lr = opt.lr * num_gpus * (opt.batch_size / 32)
    criterion = Loss(dboxes)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=opt.lr,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay,
                                nesterov=True)
    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=opt.multistep,
                            gamma=0.1)

    # ``net`` is always the unwrapped network. It is used for checkpoint I/O,
    # because ``model.module`` only exists once the model is DDP-wrapped,
    # which happens on the CUDA path alone.
    net = model
    if torch.cuda.is_available():
        model.cuda()
        criterion.cuda()

        if opt.amp:
            from apex import amp
            from apex.parallel import DistributedDataParallel as DDP
            model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        else:
            from torch.nn.parallel import DistributedDataParallel as DDP
        net = model
        # DistributedDataParallel is recommended over DataParallel for
        # multi-GPU training, even on a single node.
        model = DDP(model)

    # Start every run with a fresh log directory.
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)

    os.makedirs(opt.save_folder, exist_ok=True)
    checkpoint_path = os.path.join(opt.save_folder, "SSD.pth")

    writer = SummaryWriter(opt.log_path)

    if os.path.isfile(checkpoint_path):
        # Deserialize on CPU. This keeps every rank from materialising the
        # tensors on the GPU they were saved from. load_state_dict() then
        # copies them to each parameter's device.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        first_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["model_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    else:
        first_epoch = 0

    for epoch in range(first_epoch, opt.epochs):
        train(model, train_loader, epoch, writer, criterion, optimizer,
              scheduler, opt.amp)
        evaluate(model, test_loader, epoch, writer, encoder, opt.nms_threshold)

        # NOTE(review): every process writes the same file. Consider
        # restricting this to rank 0 in distributed runs.
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": net.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict()
        }
        torch.save(checkpoint, checkpoint_path)
Пример #24
0
                mean_generator_adversarial_loss / len(train_dataloader), mean_generator_content_loss /
                len(train_dataloader), mean_generator_total_loss / len(train_dataloader)))

        log_value('generator_perceptual_loss', mean_generator_perceptual_loss / len(train_dataloader), epoch)
        log_value('generator_adversarial_loss', mean_generator_adversarial_loss / len(train_dataloader), epoch)
        log_value('generator_content_loss', mean_generator_content_loss / len(train_dataloader), epoch)
        log_value('generator_total_loss', mean_generator_total_loss / len(train_dataloader), epoch)
        log_value('discriminator_loss', mean_discriminator_loss / len(train_dataloader), epoch)

        scheduler_generator.step()
        scheduler_discriminator.step()

        # Do checkpointing 保存模型
        generator_state = {'generator_model': generator.state_dict(),
                           'generator_optimizer': optim_generator.state_dict(),
                           'scheduler_generator': scheduler_generator.state_dict(), 'epoch': epoch}
        discriminator_state = {'discriminator_model': discriminator.state_dict(), 'discriminator_optimizer':
                               optim_discriminator.state_dict(), 'scheduler_discriminator':
                               scheduler_discriminator.state_dict(), 'epoch': epoch}

        # save model
        torch.save(generator_state, opt.generatorWeights)
        torch.save(discriminator_state, opt.discriminatorWeights)

        if epoch % 5 == 0:
            # 验证集
            out_path = 'pretraining_results/SRF_' + str(opt.upSampling) + '/'
            if not os.path.exists(out_path):
                os.makedirs(out_path)

            with torch.no_grad():
Пример #25
0
def main_train(args):
    """Train ISONet with an MSE loss, validating and checkpointing each epoch.

    Args:
        args: parsed options; reads ``resume_training``, ``cuda``,
            ``batch_size``, ``milestones``, ``lr``, ``epochs``,
            ``best_model_name``, ``data_path`` and ``save_every_epochs``.

    The model with the lowest validation loss is saved to
    ``models/<best_model_name>``. Every ``save_every_epochs`` epochs a full
    checkpoint (net, optimizer, scheduler, epoch, best loss) is written to
    ``models/checkpoint/<best_model_name>_<epoch>_<time>.cpth``. Returns
    early, without training, if ``resume_training`` is not an existing file.
    """
    # Validate the resume checkpoint before doing any expensive setup.
    if args.resume_training is not None:
        if not os.path.isfile(args.resume_training):
            print(f"{args.resume_training} 不是一个合法的文件!")
            return
        print(f"加载检查点:{args.resume_training}")
    cuda = args.cuda
    resume = args.resume_training
    batch_size = args.batch_size
    milestones = args.milestones
    lr = args.lr
    total_epoch = args.epochs
    best_model_name = args.best_model_name
    checkpoint_name = args.best_model_name
    data_path = args.data_path
    start_epoch = 1
    # 0 is the "no best model saved yet" sentinel (kept for checkpoint
    # compatibility: it is the value stored before any best model exists).
    best_test_loss = 0

    print("加载数据....")
    dataset = ISONetData(data_path=data_path)
    dataset_test = ISONetData(data_path=data_path, train=False)
    data_loader = DataLoader(dataset=dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=6,
                             pin_memory=True)
    data_loader_test = DataLoader(dataset=dataset_test,
                                  batch_size=batch_size,
                                  shuffle=False)
    print("成功加载数据...")
    print(f"训练集数量: {len(dataset)}")
    print(f"验证集数量: {len(dataset_test)}")

    model_path = Path("models")
    checkpoint_path = model_path.joinpath("checkpoint")
    # parents=True also creates models/ when it is missing.
    checkpoint_path.mkdir(parents=True, exist_ok=True)

    device = None
    if torch.cuda.is_available():
        device = torch.cuda.current_device()
    else:
        print("cuda 无效!")
        cuda = False

    net = ISONet()
    criterion = nn.MSELoss(reduction="mean")
    optimizer = optim.Adam(net.parameters(), lr=lr)

    if cuda:
        net = net.to(device=device)
        criterion = criterion.to(device=device)

    scheduler = MultiStepLR(optimizer=optimizer,
                            milestones=milestones,
                            gamma=0.1)
    writer = SummaryWriter()

    if resume:
        # Resume training from the file that was validated above.
        print("恢复训练中...")
        checkpoint = torch.load(resume, map_location=None if cuda else "cpu")
        net.load_state_dict(checkpoint["net"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_test_loss = checkpoint["best_test_loss"]

        start_epoch = checkpoint["epoch"] + 1
        print(f"从第[{start_epoch}]轮开始训练...")
        print(f"上一次的损失为: [{best_test_loss}]...")
    else:
        # Initialize weights for a fresh run.
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    record = 0  # TensorBoard step counter for the training loss
    for epoch in range(start_epoch, total_epoch):
        print(f"开始第 [{epoch}] 轮训练...")
        net.train()
        writer.add_scalar("Train/Learning Rate",
                          scheduler.get_last_lr()[0], epoch)
        for i, (data, label) in enumerate(data_loader):
            if i == 0:
                start_time = int(time.time())
            if cuda:
                data = data.to(device=device)
                label = label.to(device=device)
            label = label.unsqueeze(1)

            optimizer.zero_grad()
            output = net(data)
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            # Progress report and remaining-time estimate every 500 batches.
            if i % 500 == 499:
                end_time = int(time.time())
                print(
                    f">>> epoch[{epoch}] loss[{loss:.4f}]  {i * batch_size}/{len(dataset)} lr{scheduler.get_last_lr()} ",
                    end="")
                left_time = ((len(dataset) - i * batch_size) / 500 /
                             batch_size) * (end_time - start_time)
                print(
                    f"耗费时间:[{end_time - start_time:.2f}]秒,估计剩余时间: [{left_time:.2f}]秒"
                )
                start_time = end_time
            # Log the training loss to TensorBoard every 128 batches.
            if i % 128 == 127:
                writer.add_scalar("Train/loss", loss, record)
                record += 1

        # Validate.
        print("测试模型...")
        net.eval()

        test_loss = 0
        with torch.no_grad():
            for data, label in data_loader_test:
                if cuda:
                    data = data.to(device)
                    label = label.to(device)
                label = label.unsqueeze(1)
                # Accumulate the per-batch mean loss.
                test_loss += criterion(net(data), label).item()

        # Average over the real number of batches, so a short final batch
        # is counted correctly.
        test_loss /= max(len(data_loader_test), 1)
        print(
            f'\nTest Data: Average batch[{batch_size}] loss: {test_loss:.4f}\n'
        )
        scheduler.step()

        writer.add_scalar("Test/Loss", test_loss, epoch)

        # Save the best model so far (lowest validation loss).
        if not best_test_loss:
            print("保存模型中...")
            torch.save(net.state_dict(), model_path.joinpath(best_model_name))
            best_test_loss = test_loss
        elif test_loss < best_test_loss:
            print("获取到更好的模型,保存中...")
            torch.save(net.state_dict(), model_path.joinpath(best_model_name))
            best_test_loss = test_loss

        # Save a checkpoint. It is built after the best-loss update so that a
        # resumed run sees the up-to-date best_test_loss.
        if epoch % args.save_every_epochs == 0:
            checkpoint = {
                "net": net.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "scheduler": scheduler.state_dict(),
                "best_test_loss": best_test_loss
            }
            c_time = time2str()
            filename = f"{checkpoint_name}_{epoch}_{c_time}.cpth"
            torch.save(checkpoint, checkpoint_path.joinpath(filename))
            print(f"保存检查点: [{filename}]...\n")

    writer.close()
Пример #26
0
class AuxModel:
    """Trainer for a main classification task plus auxiliary tasks.

    Wraps the network returned by ``get_model(config)``, along with an Adam
    optimizer, a ``MultiStepLR`` scheduler stepped once per epoch, a
    TensorBoard writer and a wandb run. It provides epoch training (main task
    only, or main plus auxiliary tasks), evaluation, and checkpoint
    save/load.
    """

    # Per-dataset (number of classes, ``calculate_stat`` type) used by test().
    _DATASET_STATS = {
        'kather': (9, 'multi'),
        'oscc': (2, 'binary'),
        'cam': (2, 'binary'),
    }

    def __init__(self, config, logger, wandb):
        self.config = config
        self.logger = logger
        self.writer = SummaryWriter(config.log_dir)
        self.wandb = wandb
        cudnn.enabled = True

        # Set up the model.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = get_model(config)
        if len(config.gpus) > 1:
            self.model = nn.DataParallel(self.model)
        self.model = self.model.to(self.device)
        self.best_acc = 0
        self.best_AUC = 0
        self.class_loss_func = nn.CrossEntropyLoss()
        self.pixel_loss = nn.L1Loss()
        if config.mode == 'train':
            # Set up the optimizer and LR scheduler.
            self.optimizer = torch.optim.Adam(self.model.parameters(),
                                              lr=config.lr,
                                              betas=(.5, .999))
            self.scheduler = MultiStepLR(self.optimizer,
                                         milestones=[50, 150],
                                         gamma=0.1)
            self.wandb.watch(self.model)
            self.start_iter = 0

            # Optionally resume training.
            if config.training_resume:
                self.load(config.model_dir + '/' + config.training_resume)

            cudnn.benchmark = True
        else:
            # 'val' and any other non-train mode load the testing checkpoint.
            self.load(os.path.join(config.testing_model))

    def entropy_loss(self, x):
        """Mean over the batch of the softmax entropy of logits ``x`` (dim 1)."""
        return torch.sum(-F.softmax(x, 1) * F.log_softmax(x, 1), 1).mean()

    @staticmethod
    def _next_batch(iterators, loaders, key):
        """Return the next batch of ``loaders[key]``.

        The iterator is cached in ``iterators`` and restarted when exhausted,
        so successive calls walk through the loader instead of always taking
        a freshly created iterator's first batch.
        """
        it = iterators.get(key)
        if it is not None:
            try:
                return next(it)
            except StopIteration:
                pass
        it = iterators[key] = iter(loaders[key])
        return next(it)

    @staticmethod
    def _aux_loss_summary(aux_task_names, src_aux_loss, tar_aux_loss):
        """Format the auxiliary losses, pairing each value with its own label.

        'domain_classifier' has only a target-side loss. Every other task
        contributes a source and a target loss.
        """
        parts = []
        for task_name in aux_task_names:
            if task_name == 'domain_classifier':
                parts.append('tar_aux_{}: {:.3f}'.format(
                    task_name, float(tar_aux_loss[task_name])))
            else:
                parts.append('src_aux_{}: {:.3f} | tar_aux_{}: {:.3f}'.format(
                    task_name, float(src_aux_loss[task_name]),
                    task_name, float(tar_aux_loss[task_name])))
        return ' | '.join(parts)

    def train_epoch_main_task(self, src_loader, tar_loader, epoch, print_freq):
        """Train one epoch on the source main task only.

        ``tar_loader`` is unused; it keeps the signature identical to
        ``train_epoch_all_tasks``. The LR scheduler is stepped once at the end.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()

        for src_batch in src_loader['main_task']:
            t = time.time()
            src_imgs, src_cls_lbls = to_device(src_batch, self.device)
            self.optimizer.zero_grad()
            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            precision1_train, _ = accuracy(src_main_logits,
                                           src_cls_lbls,
                                           topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))

            loss.backward()
            self.optimizer.step()

            losses.update(loss.item(), src_imgs.size(0))
            batch_time.update(time.time() - t)
            self.start_iter += 1

            if self.start_iter % print_freq == 0:
                print_string = 'Epoch {:>2} | iter {:>4} | loss:{:.3f}| acc:{:.3f}| src_main: {:.3f} |' + '|{:4.2f} s/it'
                self.logger.info(
                    print_string.format(epoch, self.start_iter, losses.avg,
                                        top1.avg, main_loss.avg,
                                        batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
        self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})

    def train_epoch_all_tasks(self, src_loader, tar_loader, epoch, print_freq):
        """Train one epoch on the main task plus every auxiliary task.

        Each iteration combines several losses:

        * the weighted source main-task loss;
        * an entropy loss on the target main-task logits;
        * for 'classification_adapt' tasks, a domain-classifier loss on a
          shuffled source/target mix, with the DANN-style ``alpha`` passed
          to the model;
        * for 'classification_self' / 'pixel_self' tasks, weighted source and
          target auxiliary losses (cross-entropy / L1).

        The epoch runs for the length of the longest source loader in
        ``config.task_names``. Shorter loaders are cycled. The LR scheduler is
        stepped once, at the end of the epoch.
        """
        self.model.train()
        batch_time = AverageMeter()
        losses = AverageMeter()
        main_loss = AverageMeter()
        top1 = AverageMeter()
        start_steps = epoch * len(tar_loader['main_task'])
        total_steps = self.config.num_epochs * len(tar_loader['main_task'])

        max_num_iter_src = max(
            len(src_loader[task_name]) for task_name in self.config.task_names)
        # Persistent per-loader iterators for this epoch.
        src_iters = {}
        tar_iters = {}
        for it in range(max_num_iter_src):
            t = time.time()

            # DANN schedule: alpha ramps from 0 towards 1 over training.
            p = float(it + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            self.optimizer.zero_grad()

            src_imgs, src_cls_lbls = to_device(
                self._next_batch(src_iters, src_loader, 'main_task'), self.device)
            tar_imgs, _ = to_device(
                self._next_batch(tar_iters, tar_loader, 'main_task'), self.device)

            src_main_logits = self.model(src_imgs, 'main_task')
            src_main_loss = self.class_loss_func(src_main_logits, src_cls_lbls)
            loss = src_main_loss * self.config.loss_weight['main_task']
            main_loss.update(loss.item(), src_imgs.size(0))
            tar_main_logits = self.model(tar_imgs, 'main_task')
            loss += self.entropy_loss(tar_main_logits)
            tar_aux_loss = {}
            src_aux_loss = {}

            for task in self.config.task_names:
                task_type = self.config.tasks[task]['type']
                if task_type == 'classification_adapt':
                    n_src = src_imgs.size(0)
                    n_tar = tar_imgs.size(0)
                    perm = torch.randperm(n_src + n_tar)
                    # Shuffle source+target together; keep an n_src-sized mix.
                    mixed_imgs = torch.cat((src_imgs, tar_imgs),
                                           dim=0)[perm, :, :, :][:n_src]
                    # Domain labels: 0 = source, 1 = target.
                    domain_lbls = torch.cat(
                        (torch.zeros(n_src), torch.ones(n_tar)),
                        dim=0)[perm][:n_src]
                    domain_lbls = domain_lbls.long().to(self.device)
                    domain_logits = self.model(mixed_imgs,
                                               'domain_classifier', alpha)
                    tar_aux_loss['domain_classifier'] = self.class_loss_func(
                        domain_logits, domain_lbls)
                    loss += tar_aux_loss['domain_classifier'] * \
                        self.config.loss_weight['domain_classifier']
                elif task_type in ('classification_self', 'pixel_self'):
                    loss_fn = (self.class_loss_func
                               if task_type == 'classification_self'
                               else self.pixel_loss)
                    src_aux_imgs, src_aux_lbls = to_device(
                        self._next_batch(src_iters, src_loader, task), self.device)
                    tar_aux_imgs, tar_aux_lbls = to_device(
                        self._next_batch(tar_iters, tar_loader, task), self.device)
                    tar_aux_logits = self.model(tar_aux_imgs, task)
                    src_aux_logits = self.model(src_aux_imgs, task)
                    tar_aux_loss[task] = loss_fn(tar_aux_logits, tar_aux_lbls)
                    src_aux_loss[task] = loss_fn(src_aux_logits, src_aux_lbls)
                    task_weight = self.config.loss_weight[task]
                    loss += src_aux_loss[task] * task_weight
                    loss += tar_aux_loss[task] * task_weight

            precision1_train, _ = accuracy(src_main_logits,
                                           src_cls_lbls,
                                           topk=(1, 2))
            top1.update(precision1_train[0], src_imgs.size(0))
            loss.backward()
            self.optimizer.step()
            losses.update(loss.item(), src_imgs.size(0))

            batch_time.update(time.time() - t)
            self.start_iter += 1
            if self.start_iter % print_freq == 0:
                aux_summary = self._aux_loss_summary(
                    self.config.aux_task_names, src_aux_loss, tar_aux_loss)
                self.logger.info(
                    'Epoch {:>2} | iter {:>4} | loss:{:.3f} |  acc: {:.3f} | src_main: {:.3f} | {} | {:4.2f} s/it'
                    .format(epoch, self.start_iter, losses.avg, top1.avg,
                            main_loss.avg, aux_summary, batch_time.avg))
                self.writer.add_scalar('losses/all_loss', losses.avg,
                                       self.start_iter)
                self.writer.add_scalar('losses/src_main_loss', src_main_loss,
                                       self.start_iter)
                for task_name in self.config.aux_task_names:
                    if task_name != 'domain_classifier':
                        self.writer.add_scalar(
                            'losses/src_aux_loss_' + task_name,
                            src_aux_loss[task_name], self.start_iter)
                    self.writer.add_scalar(
                        'losses/tar_aux_loss_' + task_name,
                        tar_aux_loss[task_name], self.start_iter)
        # Step once per epoch, as in train_epoch_main_task.
        self.scheduler.step()
        self.wandb.log({"Train Loss": main_loss.avg})

    def train(self, src_loader, tar_loader, val_loader, test_loader):
        """Run the training loop from the current iteration to ``num_epochs``.

        After each epoch the model is saved as 'last'. If ``val_loader`` is
        given, the best-accuracy and best-AUC models are saved. If
        ``test_loader`` is given, test accuracy is logged.
        """
        num_batches = len(src_loader['main_task'])
        print_freq = max(num_batches // self.config.training_num_print_epoch,
                         1)
        start_epoch = self.start_iter // num_batches
        for epoch in range(start_epoch, self.config.num_epochs):
            if len(self.config.task_names) == 1:
                self.train_epoch_main_task(src_loader, tar_loader, epoch,
                                           print_freq)
            else:
                self.train_epoch_all_tasks(src_loader, tar_loader, epoch,
                                           print_freq)
            self.logger.info('learning rate: %f ' % get_lr(self.optimizer))
            self.save(self.config.model_dir, 'last')

            if val_loader is not None:
                self.logger.info('validating...')
                class_acc, AUC = self.test(val_loader)
                self.writer.add_scalar('val/class_acc', class_acc,
                                       self.start_iter)
                if class_acc > self.best_acc:
                    self.best_acc = class_acc
                    self.save(self.config.best_model_dir, 'best_acc')
                if AUC > self.best_AUC:
                    self.best_AUC = AUC
                    self.save(self.config.best_model_dir, 'best_AUC')
                self.logger.info('Best validation accuracy: {:.2f} %'.format(
                    self.best_acc))

            if test_loader is not None:
                self.logger.info('testing...')
                # test() returns (accuracy, AUC); only accuracy is logged here.
                class_acc, _ = self.test(test_loader)
                self.writer.add_scalar('test/class_acc', class_acc,
                                       self.start_iter)
                self.logger.info(
                    'Best testing accuracy: {:.2f} %'.format(class_acc))

        self.logger.info('Best validation accuracy: {:.2f} %'.format(
            self.best_acc))
        self.logger.info('Finished Training.')

    def save(self, path, ext):
        """Save model/optimizer/scheduler state to ``<path>/model_<ext>.pth``."""
        state = {
            "iter": self.start_iter + 1,
            "model_state": self.model.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "best_acc": self.best_acc,
        }
        save_path = os.path.join(path, f'model_{ext}.pth')
        self.logger.info('Saving model to %s' % save_path)
        torch.save(state, save_path)

    def load(self, path):
        """Load model weights from ``path``.

        In train mode, the optimizer, the scheduler, the iteration counter
        and the best accuracy are restored as well.
        """
        checkpoint = torch.load(path)
        self.model.load_state_dict(checkpoint['model_state'])
        self.logger.info('Loaded model from: ' + path)

        if self.config.mode == 'train':
            self.optimizer.load_state_dict(checkpoint['optimizer_state'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state'])
            self.start_iter = checkpoint['iter']
            self.best_acc = checkpoint['best_acc']
            self.logger.info('Start iter: %d ' % self.start_iter)

    def test(self, val_loader):
        """Evaluate the main task on ``val_loader``.

        Returns:
            ``(class_acc, AUC)``: top-1 accuracy in percent, and the value
            returned by ``calculate_stat`` on the collected softmax outputs.
            Outputs are collected only when ``config.save_output == True``.

        Raises:
            ValueError: if ``config.dataset`` is not 'kather', 'oscc' or 'cam'.
        """
        try:
            num_classes, stat_type = self._DATASET_STATS[self.config.dataset]
        except KeyError:
            raise ValueError(
                f'unsupported dataset: {self.config.dataset!r}') from None

        loss = AverageMeter()
        class_correct = 0
        total = 0
        soft_chunks = []
        label_chunks = []
        self.model.eval()
        with torch.no_grad():
            for data in tqdm(val_loader, total=len(val_loader),
                             desc="Validating"):
                imgs, cls_lbls = to_device(data, self.device)
                logits = self.model(imgs, 'main_task')
                test_loss = self.class_loss_func(logits, cls_lbls)
                loss.update(test_loss.item(), imgs.size(0))
                if self.config.save_output == True:
                    soft_chunks.append(F.softmax(logits, dim=1).cpu().numpy())
                    label_chunks.append(np.ravel(cls_lbls.cpu().numpy()))
                _, cls_pred = logits.max(dim=1)
                class_correct += torch.sum(cls_pred == cls_lbls)
                total += imgs.size(0)

        self.wandb.log({"Test Loss": loss.avg})
        # Concatenate once, instead of growing an array batch by batch.
        if soft_chunks:
            soft_labels = np.concatenate(soft_chunks, axis=0).astype(np.float64)
            true_labels = np.concatenate(label_chunks).astype(np.float64)
        else:
            soft_labels = np.zeros((0, num_classes))
            true_labels = []
        AUC = calculate_stat(soft_labels,
                             true_labels,
                             num_classes,
                             self.config.class_names,
                             type=stat_type,
                             thresh=0.5)
        class_acc = 100 * float(class_correct) / total
        self.logger.info('class_acc: {:.2f} %'.format(class_acc))
        self.wandb.log({"Test acc": class_acc, "Test AUC": 100 * AUC})
        return class_acc, AUC
def main():
    """Train and evaluate a speaker-embedding model.

    Relies on module-level names defined elsewhere in the script: ``args``,
    the datasets ``train_dir`` / ``valid_dir`` / ``test_part``, the DataLoader
    options ``kwargs``, the optimizer options ``opt_kwargs``, the TensorBoard
    ``writer`` and the per-epoch ``train`` / ``test`` functions.

    Raises:
        ValueError: if ``args.loss_type`` is not one of
            'soft', 'asoft', 'center' or 'amsoft'.
    """
    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m'.format(str(time.asctime())))
    print('Parsed options: {}'.format(vars(args)))
    print('Number of Classes: {}\n'.format(len(train_dir.speakers)))

    # instantiate model and initialize weights
    model_kwargs = {
        'input_dim': args.feat_dim,
        'embedding_size': args.embedding_size,
        'num_classes': len(train_dir.speakers),
        'dropout_p': args.dropout_p
    }

    print('Model options: {}'.format(model_kwargs))

    model = create_model(args.model, **model_kwargs)

    if args.cuda:
        model.cuda()

    start = 0
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)  # load the file only once
            start = checkpoint['epoch']
            # Drop BatchNorm 'num_batches_tracked' buffers before loading
            # (presumably for compatibility across PyTorch versions).
            filtered = {
                k: v
                for k, v in checkpoint['state_dict'].items()
                if 'num_batches_tracked' not in k
            }
            model.load_state_dict(filtered)
            # optimizer.load_state_dict(checkpoint['optimizer'])
            # scheduler.load_state_dict(checkpoint['scheduler'])
            # criterion.load_state_dict(checkpoint['criterion'])
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # NOTE(review): for 'asoft'/'amsoft' the classifier head is replaced
    # *after* the resume load above, so any resumed classifier weights are
    # discarded (this looks intended for fine-tuning from a softmax model).
    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        model.classifier = AngleLinear(in_features=args.embedding_size,
                                       out_features=train_dir.num_spks,
                                       m=args.m)
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min,
                                        lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks,
                                  feat_dim=args.embedding_size)
    elif args.loss_type == 'amsoft':
        model.classifier = AdditiveMarginLinear(feat_dim=args.embedding_size,
                                                n_classes=train_dir.num_spks)
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    else:
        # Fail early instead of hitting a NameError on xe_criterion later.
        raise ValueError('Unsupported loss type: {}'.format(args.loss_type))

    if args.loss_type == 'center':
        # Center-loss centres are learnable and get a 5x learning rate.
        optimizer = torch.optim.SGD([{
            'params': xe_criterion.parameters(),
            'lr': args.lr * 5
        }, {
            'params': model.parameters()
        }],
                                    lr=args.lr,
                                    weight_decay=args.weight_decay,
                                    momentum=args.momentum)
    else:
        optimizer = create_optimizer(model.parameters(), args.optimizer,
                                     **opt_kwargs)

    if args.finetune:
        if args.loss_type == 'asoft' or args.loss_type == 'amsoft':
            # Give the freshly created classifier head a 5x learning rate.
            classifier_ids = {id(p) for p in model.classifier.parameters()}
            rest_params = [
                p for p in model.parameters() if id(p) not in classifier_ids
            ]
            optimizer = torch.optim.SGD(
                [{
                    'params': model.classifier.parameters(),
                    'lr': args.lr * 5
                }, {
                    'params': rest_params
                }],
                lr=args.lr,
                weight_decay=args.weight_decay,
                momentum=args.momentum)

    # e.g. "10,20,30" -> [10, 20, 30]; empty entries are ignored.
    milestones = sorted(
        int(x) for x in args.milestones.split(',') if x.strip())
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start)
        torch.save(
            {
                'epoch': start,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict()
            }, check_path)

    start += args.start_epoch
    print('Start epoch is : ' + str(start))
    end = args.epochs + 1

    train_loader = torch.utils.data.DataLoader(train_dir,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               **kwargs)
    valid_loader = torch.utils.data.DataLoader(valid_dir,
                                               batch_size=args.batch_size,
                                               shuffle=False,
                                               **kwargs)
    test_loader = torch.utils.data.DataLoader(test_part,
                                              batch_size=args.test_batch_size,
                                              shuffle=False,
                                              **kwargs)

    ce = [ce_criterion, xe_criterion]
    if args.cuda:
        # Moved again here because the classifier / criteria above may have
        # been created after the first model.cuda() call.
        model = model.cuda()
        ce = [c.cuda() if c is not None else None for c in ce]

    for epoch in range(start, end):
        print('\n\33[1;34m Current \'{}\' learning rate is '.format(
            args.optimizer),
              end='')
        for param_group in optimizer.param_groups:
            print('{:.5f} '.format(param_group['lr']), end='')
        print(' \33[0m')

        train(train_loader, model, optimizer, ce, epoch)
        test(test_loader, valid_loader, model, epoch)

        scheduler.step()

    writer.close()
Пример #28
0
class Processor():
    """Processor for skeleton-based action recognition.

    Builds the model, optimizer, LR scheduler, data loaders and TensorBoard
    writers from the parsed argument namespace ``arg``, then runs training or
    testing via :meth:`start`.
    """

    def __init__(self, arg):
        self.arg = arg
        # Initialised before load_model(), which may overwrite it with the
        # step parsed from the weights filename; it must not be reset after.
        self.global_step = 0
        self.save_arg()
        if arg.phase == 'train':
            # Added control through the command line
            arg.train_feeder_args[
                'debug'] = arg.train_feeder_args['debug'] or self.arg.debug
            logdir = os.path.join(arg.work_dir, 'trainlogs')
            if not arg.train_feeder_args['debug']:
                if os.path.isdir(logdir):
                    print(f'log_dir {logdir} already exists')
                    if arg.assume_yes:
                        answer = 'y'
                    else:
                        answer = input('delete it? [y]/n:')
                    if answer.lower() in ('y', ''):
                        shutil.rmtree(logdir)
                        print('Dir removed:', logdir)
                    else:
                        print('Dir not removed:', logdir)

                self.train_writer = SummaryWriter(
                    os.path.join(logdir, 'train'), 'train')
                self.val_writer = SummaryWriter(os.path.join(logdir, 'val'),
                                                'val')
            else:
                # Debug runs only get a training writer; no val_writer.
                self.train_writer = SummaryWriter(
                    os.path.join(logdir, 'debug'), 'debug')

        self.load_model()
        self.load_param_groups()
        self.load_optimizer()
        self.load_lr_scheduler()
        self.load_data()

        self.lr = self.arg.base_lr
        self.best_acc = 0
        self.best_acc_epoch = 0

        if self.arg.half:
            self.print_log('*************************************')
            self.print_log('*** Using Half Precision Training ***')
            self.print_log('*************************************')
            self.model, self.optimizer = apex.amp.initialize(
                self.model,
                self.optimizer,
                opt_level=f'O{self.arg.amp_opt_level}')
            if self.arg.amp_opt_level != 1:
                self.print_log(
                    '[WARN] nn.DataParallel is not yet supported by amp_opt_level != "O1"'
                )

        if isinstance(self.arg.device, list) and len(self.arg.device) > 1:
            self.print_log(
                f'{len(self.arg.device)} GPUs available, using DataParallel')
            self.model = nn.DataParallel(self.model,
                                         device_ids=self.arg.device,
                                         output_device=self.output_device)

    def load_model(self):
        """Instantiate model and loss on the output device; load weights if given."""
        output_device = (self.arg.device[0] if isinstance(
            self.arg.device, list) else self.arg.device)
        self.output_device = output_device
        Model = import_class(self.arg.model)

        # Copy model file and main
        shutil.copy2(inspect.getfile(Model), self.arg.work_dir)
        shutil.copy2(os.path.join('.', __file__), self.arg.work_dir)

        self.model = Model(**self.arg.model_args).cuda(output_device)
        self.loss = nn.CrossEntropyLoss().cuda(output_device)
        self.print_log(
            f'Model total number of params: {count_params(self.model)}')

        if self.arg.weights:
            # save_weights() writes 'weights-<epoch>-<global_step>.pt'; strip
            # the 3-char suffix and take the last '-' field as the step.
            try:
                self.global_step = int(self.arg.weights[:-3].split('-')[-1])
            except ValueError:
                print('Cannot parse global_step from model weights filename')
                self.global_step = 0

            self.print_log(f'Loading weights from {self.arg.weights}')
            if '.pkl' in self.arg.weights:
                # pickle needs a binary handle. NOTE: unpickling executes
                # arbitrary code; only load trusted weight files.
                with open(self.arg.weights, 'rb') as f:
                    weights = pickle.load(f)
            else:
                weights = torch.load(self.arg.weights)

            # Strip DataParallel 'module.' prefixes and move to output device.
            weights = OrderedDict(
                [[k.split('module.')[-1],
                  v.cuda(output_device)] for k, v in weights.items()])

            for w in self.arg.ignore_weights:
                if weights.pop(w, None) is not None:
                    self.print_log(f'Sucessfully Remove Weights: {w}')
                else:
                    self.print_log(f'Can Not Remove Weights: {w}')

            try:
                self.model.load_state_dict(weights)
            except RuntimeError:
                # Partial load: keep the model's own values for missing keys.
                state = self.model.state_dict()
                diff = list(set(state.keys()).difference(set(weights.keys())))
                self.print_log('Can not find these weights:')
                for d in diff:
                    self.print_log('  ' + d)
                state.update(weights)
                self.model.load_state_dict(state)

    def load_param_groups(self):
        """
        Template function for setting different learning behaviour
        (e.g. LR, weight decay) of different groups of parameters
        """
        self.param_groups = defaultdict(list)
        self.param_groups['other'].extend(self.model.parameters())

        self.optim_param_groups = {
            'other': {
                'params': self.param_groups['other']
            }
        }

    def load_optimizer(self):
        """Create the optimizer and optionally restore its state.

        Raises:
            ValueError: if ``arg.optimizer`` is neither 'SGD' nor 'Adam'.
        """
        params = list(self.optim_param_groups.values())
        if self.arg.optimizer == 'SGD':
            self.optimizer = optim.SGD(params,
                                       lr=self.arg.base_lr,
                                       momentum=0.9,
                                       nesterov=self.arg.nesterov,
                                       weight_decay=self.arg.weight_decay)
        elif self.arg.optimizer == 'Adam':
            self.optimizer = optim.Adam(params,
                                        lr=self.arg.base_lr,
                                        weight_decay=self.arg.weight_decay)
        else:
            raise ValueError('Unsupported optimizer: {}'.format(
                self.arg.optimizer))

        # Load optimizer states if any
        if self.arg.checkpoint is not None:
            self.print_log(
                f'Loading optimizer states from: {self.arg.checkpoint}')
            self.optimizer.load_state_dict(
                torch.load(self.arg.checkpoint)['optimizer_states'])
            current_lr = self.optimizer.param_groups[0]['lr']
            self.print_log(f'Starting LR: {current_lr}')
            self.print_log(
                f'Starting WD1: {self.optimizer.param_groups[0]["weight_decay"]}'
            )
            if len(self.optimizer.param_groups) >= 2:
                self.print_log(
                    f'Starting WD2: {self.optimizer.param_groups[1]["weight_decay"]}'
                )

    def load_lr_scheduler(self):
        """Create a MultiStepLR (gamma 0.1) and optionally restore its state."""
        self.lr_scheduler = MultiStepLR(self.optimizer,
                                        milestones=self.arg.step,
                                        gamma=0.1)
        if self.arg.checkpoint is not None:
            scheduler_states = torch.load(
                self.arg.checkpoint)['lr_scheduler_states']
            self.print_log(
                f'Loading LR scheduler states from: {self.arg.checkpoint}')
            self.lr_scheduler.load_state_dict(scheduler_states)
            self.print_log(
                f'Starting last epoch: {scheduler_states["last_epoch"]}')
            # The restored milestones may differ from arg.step.
            self.print_log(
                f'Loaded milestones: {sorted(scheduler_states["milestones"])}')

    def load_data(self):
        """Build the train (train phase only) and test data loaders."""
        Feeder = import_class(self.arg.feeder)
        self.data_loader = dict()

        def worker_seed_fn(worker_id):
            # give workers different seeds
            return init_seed(self.arg.seed + worker_id + 1)

        if self.arg.phase == 'train':
            self.data_loader['train'] = torch.utils.data.DataLoader(
                dataset=Feeder(**self.arg.train_feeder_args),
                batch_size=self.arg.batch_size,
                shuffle=True,
                num_workers=self.arg.num_worker,
                drop_last=True,
                worker_init_fn=worker_seed_fn)

        self.data_loader['test'] = torch.utils.data.DataLoader(
            dataset=Feeder(**self.arg.test_feeder_args),
            batch_size=self.arg.test_batch_size,
            shuffle=False,
            num_workers=self.arg.num_worker,
            drop_last=False,
            worker_init_fn=worker_seed_fn)

    def save_arg(self):
        """Dump the argument namespace to <work_dir>/config.yaml."""
        arg_dict = vars(self.arg)
        os.makedirs(self.arg.work_dir, exist_ok=True)
        with open(os.path.join(self.arg.work_dir, 'config.yaml'), 'w') as f:
            yaml.dump(arg_dict, f)

    def print_time(self):
        localtime = time.asctime(time.localtime(time.time()))
        self.print_log(f'Local current time: {localtime}')

    def print_log(self, s, print_time=True):
        """Print ``s`` (optionally timestamped); append to log.txt if enabled."""
        if print_time:
            localtime = time.asctime(time.localtime(time.time()))
            s = f'[ {localtime} ] {s}'
        print(s)
        if self.arg.print_log:
            with open(os.path.join(self.arg.work_dir, 'log.txt'), 'a') as f:
                print(s, file=f)

    def record_time(self):
        self.cur_time = time.time()
        return self.cur_time

    def split_time(self):
        """Return seconds since the last record_time() and restart the clock."""
        split_time = time.time() - self.cur_time
        self.record_time()
        return split_time

    def save_states(self, epoch, states, out_folder, out_name):
        """torch.save ``states`` to <work_dir>/<out_folder>/<out_name>."""
        out_folder_path = os.path.join(self.arg.work_dir, out_folder)
        out_path = os.path.join(out_folder_path, out_name)
        os.makedirs(out_folder_path, exist_ok=True)
        torch.save(states, out_path)

    def save_checkpoint(self, epoch, out_folder='checkpoints'):
        """Save optimizer and LR-scheduler states (not model weights)."""
        state_dict = {
            'epoch': epoch,
            'optimizer_states': self.optimizer.state_dict(),
            'lr_scheduler_states': self.lr_scheduler.state_dict(),
        }

        checkpoint_name = f'checkpoint-{epoch}-fwbz{self.arg.forward_batch_size}-{int(self.global_step)}.pt'
        self.save_states(epoch, state_dict, out_folder, checkpoint_name)

    def save_weights(self, epoch, out_folder='weights'):
        """Save CPU copies of the model weights without 'module.' prefixes."""
        state_dict = self.model.state_dict()
        weights = OrderedDict([[k.split('module.')[-1],
                                v.cpu()] for k, v in state_dict.items()])

        weights_name = f'weights-{epoch}-{int(self.global_step)}.pt'
        self.save_states(epoch, weights, out_folder, weights_name)

    def train(self, epoch, save_model=False):
        """Run one training epoch with gradient accumulation.

        Each loader batch of ``arg.batch_size`` is split into chunks of
        ``arg.forward_batch_size``. Gradients are accumulated over the chunks,
        then one optimizer step is taken per loader batch.
        """
        self.model.train()
        loader = self.data_loader['train']
        loss_values = []
        self.train_writer.add_scalar('epoch', epoch + 1, self.global_step)
        self.record_time()
        timer = dict(dataloader=0.001, model=0.001, statistics=0.001)

        current_lr = self.optimizer.param_groups[0]['lr']
        self.print_log(f'Training epoch: {epoch + 1}, LR: {current_lr:.4f}')

        process = tqdm(loader, dynamic_ncols=True)
        for data, label, index in process:
            self.global_step += 1
            # get data
            with torch.no_grad():
                data = data.float().cuda(self.output_device)
                label = label.long().cuda(self.output_device)
            timer['dataloader'] += self.split_time()

            # backward
            self.optimizer.zero_grad()

            ############## Gradient Accumulation for Smaller Batches ##############
            real_batch_size = self.arg.forward_batch_size
            splits = len(data) // real_batch_size
            # Explicit check instead of `assert` so it survives `python -O`.
            if len(data) % real_batch_size != 0:
                raise ValueError(
                    'Real batch size should be a factor of arg.batch_size!')

            for i in range(splits):
                left = i * real_batch_size
                right = left + real_batch_size
                batch_data, batch_label = data[left:right], label[left:right]

                # forward
                output = self.model(batch_data)
                if isinstance(output, tuple):
                    output, l1 = output
                    l1 = l1.mean()
                else:
                    l1 = 0

                # Divide by splits so the accumulated gradient matches one
                # full-batch backward pass.
                loss = self.loss(output, batch_label) / splits

                if self.arg.half:
                    with apex.amp.scale_loss(loss,
                                             self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                loss_values.append(loss.item())
                timer['model'] += self.split_time()

                # Display loss
                process.set_description(
                    f'(BS {real_batch_size}) loss: {loss.item():.4f}')

                value, predict_label = torch.max(output, 1)
                acc = torch.mean((predict_label == batch_label).float())

                self.train_writer.add_scalar('acc', acc, self.global_step)
                self.train_writer.add_scalar('loss',
                                             loss.item() * splits,
                                             self.global_step)
                self.train_writer.add_scalar('loss_l1', l1, self.global_step)

            #####################################

            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 2)
            self.optimizer.step()

            # statistics
            self.lr = self.optimizer.param_groups[0]['lr']
            self.train_writer.add_scalar('lr', self.lr, self.global_step)
            timer['statistics'] += self.split_time()

            # Delete output/loss after each batch since it may introduce extra mem during scoping
            # https://discuss.pytorch.org/t/gpu-memory-consumption-increases-while-training/2770/3
            del output
            del loss

        # statistics of time consumption and loss
        proportion = {
            k: f'{int(round(v * 100 / sum(timer.values()))):02d}%'
            for k, v in timer.items()
        }

        mean_loss = np.mean(loss_values)
        num_splits = self.arg.batch_size // self.arg.forward_batch_size
        self.print_log(
            f'\tMean training loss: {mean_loss:.4f} (BS {self.arg.batch_size}: {mean_loss * num_splits:.4f}).'
        )
        self.print_log(
            '\tTime consumption: [Data]{dataloader}, [Network]{model}'.format(
                **proportion))

        # PyTorch > 1.2.0: update LR scheduler here with `.step()`
        # and make sure to save the `lr_scheduler.state_dict()` as part of checkpoint
        self.lr_scheduler.step()

        if save_model:
            # save training checkpoint & weights
            self.save_weights(epoch + 1)
            self.save_checkpoint(epoch + 1)

    def eval(self,
             epoch,
             save_score=False,
             loader_name=('test',),
             wrong_file=None,
             result_file=None):
        """Evaluate on every loader named in ``loader_name``.

        For each loader, this logs loss and top-k accuracy and updates
        best_acc. In training mode it also writes to val_writer, and it can
        optionally pickle the per-sample scores.
        ``wrong_file`` receives 'index,pred,true' lines for misclassified
        samples; ``result_file`` receives 'pred,true' for every sample.
        """
        # Skip evaluation if too early
        if epoch + 1 < self.arg.eval_start:
            return

        f_w = open(wrong_file, 'w') if wrong_file is not None else None
        f_r = open(result_file, 'w') if result_file is not None else None
        try:
            with torch.no_grad():
                self.model = self.model.cuda(self.output_device)
                self.model.eval()
                self.print_log(f'Eval epoch: {epoch + 1}')
                for ln in loader_name:
                    loss_values = []
                    score_batches = []
                    l1 = 0
                    process = tqdm(self.data_loader[ln], dynamic_ncols=True)
                    for data, label, index in process:
                        data = data.float().cuda(self.output_device)
                        label = label.long().cuda(self.output_device)
                        output = self.model(data)
                        if isinstance(output, tuple):
                            output, l1 = output
                            l1 = l1.mean()
                        else:
                            l1 = 0
                        loss = self.loss(output, label)
                        score_batches.append(output.data.cpu().numpy())
                        loss_values.append(loss.item())

                        _, predict_label = torch.max(output.data, 1)

                        if f_w is not None or f_r is not None:
                            predict = list(predict_label.cpu().numpy())
                            true = list(label.data.cpu().numpy())
                            for i, x in enumerate(predict):
                                if f_r is not None:
                                    f_r.write(
                                        str(x) + ',' + str(true[i]) + '\n')
                                if x != true[i] and f_w is not None:
                                    f_w.write(
                                        str(index[i]) + ',' + str(x) + ',' +
                                        str(true[i]) + '\n')

                    # Statistics are computed per loader; previously they
                    # were only computed for the last loader in the list.
                    score = np.concatenate(score_batches)
                    loss = np.mean(loss_values)
                    accuracy = self.data_loader[ln].dataset.top_k(score, 1)
                    if accuracy > self.best_acc:
                        self.best_acc = accuracy
                        self.best_acc_epoch = epoch + 1

                    print('Accuracy: ', accuracy, ' model: ',
                          self.arg.work_dir)
                    # val_writer only exists for non-debug training runs.
                    if (self.arg.phase == 'train'
                            and not self.arg.train_feeder_args['debug']):
                        self.val_writer.add_scalar('loss', loss,
                                                   self.global_step)
                        self.val_writer.add_scalar('loss_l1', l1,
                                                   self.global_step)
                        self.val_writer.add_scalar('acc', accuracy,
                                                   self.global_step)

                    score_dict = dict(
                        zip(self.data_loader[ln].dataset.sample_name, score))
                    self.print_log(
                        f'\tMean {ln} loss of {len(self.data_loader[ln])} batches: {np.mean(loss_values)}.'
                    )
                    for k in self.arg.show_topk:
                        self.print_log(
                            f'\tTop {k}: {100 * self.data_loader[ln].dataset.top_k(score, k):.2f}%'
                        )

                    if save_score:
                        with open(
                                '{}/epoch{}_{}_score.pkl'.format(
                                    self.arg.work_dir, epoch + 1, ln),
                                'wb') as f:
                            pickle.dump(score_dict, f)
        finally:
            # Always release the optional result files.
            for handle in (f_w, f_r):
                if handle is not None:
                    handle.close()

        # Empty cache after evaluation
        torch.cuda.empty_cache()

    def start(self):
        """Run the full training schedule or a single test evaluation."""
        if self.arg.phase == 'train':
            self.print_log(f'Parameters:\n{pprint.pformat(vars(self.arg))}\n')
            self.print_log(
                f'Model total number of params: {count_params(self.model)}')
            # len(DataLoader) is already the number of batches per epoch, and
            # train() increments global_step once per batch.
            self.global_step = self.arg.start_epoch * len(
                self.data_loader['train'])
            for epoch in range(self.arg.start_epoch, self.arg.num_epoch):
                save_model = ((epoch + 1) % self.arg.save_interval
                              == 0) or (epoch + 1 == self.arg.num_epoch)
                self.train(epoch, save_model=save_model)
                self.eval(epoch,
                          save_score=self.arg.save_score,
                          loader_name=['test'])

            num_params = sum(p.numel() for p in self.model.parameters()
                             if p.requires_grad)
            self.print_log(f'Best accuracy: {self.best_acc}')
            self.print_log(f'Epoch number: {self.best_acc_epoch}')
            self.print_log(f'Model name: {self.arg.work_dir}')
            self.print_log(f'Model total number of params: {num_params}')
            self.print_log(f'Weight decay: {self.arg.weight_decay}')
            self.print_log(f'Base LR: {self.arg.base_lr}')
            self.print_log(f'Batch Size: {self.arg.batch_size}')
            self.print_log(
                f'Forward Batch Size: {self.arg.forward_batch_size}')
            self.print_log(f'Test Batch Size: {self.arg.test_batch_size}')

        elif self.arg.phase == 'test':
            if not self.arg.test_feeder_args['debug']:
                wf = os.path.join(self.arg.work_dir, 'wrong-samples.txt')
                rf = os.path.join(self.arg.work_dir, 'right-samples.txt')
            else:
                wf = rf = None
            if self.arg.weights is None:
                raise ValueError('Please appoint --weights.')

            self.print_log(f'Model:   {self.arg.model}')
            self.print_log(f'Weights: {self.arg.weights}')

            self.eval(epoch=0,
                      save_score=self.arg.save_score,
                      loader_name=['test'],
                      wrong_file=wf,
                      result_file=rf)

            self.print_log('Done.\n')
Пример #29
0
def main():
    """Train or evaluate a stereo disparity / depth network.

    Relies on module-level names defined elsewhere in the script: ``args``,
    the global ``best_RMSE``, the ``train`` / ``test`` / ``inference`` step
    functions and ``save_checkpoint``. The function exits the process via
    ``sys.exit()`` after depth-map generation (``args.generate_depth_map``)
    or one-off evaluation (``args.evaluate``).

    Raises:
        NotImplementedError: if ``args.data_type`` is neither 'disparity'
            nor 'depth'.
    """
    global best_RMSE

    lw = utils_func.LossWise(args.api_key, args.losswise_tag, args.epochs - 1)
    # set logger
    log = logger.setup_logger(os.path.join(args.save_path, 'training.log'))
    for key, value in sorted(vars(args).items()):
        log.info(str(key) + ': ' + str(value))

    # set tensorboard
    writer = SummaryWriter(args.save_path + '/tensorboardx')

    # Data Loader
    if args.generate_depth_map:
        TrainImgLoader = None
        import dataloader.KITTI_submission_loader as KITTI_submission_loader
        TestImgLoader = torch.utils.data.DataLoader(
            KITTI_submission_loader.SubmiteDataset(args.datapath,
                                                   args.data_list,
                                                   args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=args.workers,
            drop_last=False)
    elif args.dataset == 'kitti':
        train_data, val_data = KITTILoader3D.dataloader(
            args.datapath,
            args.split_train,
            args.split_val,
            kitti2015=args.kitti2015)
        TrainImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(train_data,
                                                True,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
        TestImgLoader = torch.utils.data.DataLoader(
            KITTILoader_dataset3d.myImageFloder(val_data,
                                                False,
                                                kitti2015=args.kitti2015,
                                                dynamic_bs=args.dynamic_bs),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False,
            pin_memory=True)
    else:
        train_data, val_data = listflowfile.dataloader(args.datapath)
        TrainImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(train_data,
                                          True,
                                          calib=args.calib_value),
            batch_size=args.btrain,
            shuffle=True,
            num_workers=8,
            drop_last=False)
        TestImgLoader = torch.utils.data.DataLoader(
            SceneFlowLoader.myImageFloder(val_data,
                                          False,
                                          calib=args.calib_value),
            batch_size=args.bval,
            shuffle=False,
            num_workers=8,
            drop_last=False)

    # Load Model
    if args.data_type == 'disparity':
        model = disp_models.__dict__[args.arch](maxdisp=args.maxdisp)
    elif args.data_type == 'depth':
        model = models.__dict__[args.arch](maxdepth=args.maxdepth,
                                           maxdisp=args.maxdisp,
                                           down=args.down,
                                           scale=args.scale)
    else:
        log.info('Model is not implemented')
        # Explicit raise: a bare `assert False` is stripped under `python -O`.
        raise NotImplementedError('Model is not implemented')

    # Number of parameters
    log.info('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))
    model = nn.DataParallel(model).cuda()
    torch.backends.cudnn.benchmark = True

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999))
    scheduler = MultiStepLR(optimizer,
                            milestones=args.lr_stepsize,
                            gamma=args.lr_gamma)

    if args.pretrain:
        if os.path.isfile(args.pretrain):
            log.info("=> loading pretrain '{}'".format(args.pretrain))
            checkpoint = torch.load(args.pretrain)
            model.load_state_dict(checkpoint['state_dict'], strict=False)
        else:
            log.info('[Attention]: Do not find checkpoint {}'.format(
                args.pretrain))

    if args.resume:
        if os.path.isfile(args.resume):
            log.info("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            best_RMSE = checkpoint['best_RMSE']
            scheduler.load_state_dict(checkpoint['scheduler'])
            log.info("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            log.info('[Attention]: Do not find checkpoint {}'.format(
                args.resume))

    if args.generate_depth_map:
        depth_dir = os.path.join(args.save_path, 'depth_maps', args.data_tag)
        os.makedirs(depth_dir, exist_ok=True)

        tqdm_eval_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
        for imgL_crop, imgR_crop, calib, H, W, filename in tqdm_eval_loader:
            pred_disp = inference(imgL_crop, imgR_crop, calib, model)
            for idx, name in enumerate(filename):
                # Undo the loader's padding: keep the last H rows / first W cols.
                np.save(os.path.join(depth_dir, name),
                        pred_disp[idx][-H[idx]:, :W[idx]])
        writer.close()
        import sys
        sys.exit()

    # evaluation
    if args.evaluate:
        evaluate_metric = utils_func.Metric()
        for batch_idx, (imgL_crop, imgR_crop, disp_crop_L,
                        calib) in enumerate(TestImgLoader):
            start_time = time.time()
            test(imgL_crop, imgR_crop, disp_crop_L, calib, evaluate_metric,
                 optimizer, model)

            log.info(
                evaluate_metric.print(batch_idx, 'EVALUATE') +
                ' Time:{:.3f}'.format(time.time() - start_time))
        writer.close()
        import sys
        sys.exit()

    for epoch in range(args.start_epoch, args.epochs):
        ## training ##
        train_metric = utils_func.Metric()
        tqdm_train_loader = tqdm(TrainImgLoader, total=len(TrainImgLoader))
        for imgL_crop, imgR_crop, disp_crop_L, calib in tqdm_train_loader:
            train(imgL_crop, imgR_crop, disp_crop_L, calib, train_metric,
                  optimizer, model, epoch)
        log.info(train_metric.print(0, 'TRAIN Epoch' + str(epoch)))
        train_metric.tensorboard(writer, epoch, token='TRAIN')
        lw.update(train_metric.get_info(), epoch, 'Train')

        # Since PyTorch 1.1 the scheduler must be stepped *after* the
        # epoch's optimizer updates; stepping first shifts every milestone
        # one epoch early. Stepping before saving keeps resume consistent.
        scheduler.step()

        ## testing ##
        is_best = False
        if epoch == 0 or ((epoch + 1) % args.eval_interval) == 0:
            test_metric = utils_func.Metric()
            tqdm_test_loader = tqdm(TestImgLoader, total=len(TestImgLoader))
            for imgL_crop, imgR_crop, disp_crop_L, calib in tqdm_test_loader:
                test(imgL_crop, imgR_crop, disp_crop_L, calib, test_metric,
                     optimizer, model)
            log.info(test_metric.print(0, 'TEST Epoch' + str(epoch)))
            test_metric.tensorboard(writer, epoch, token='TEST')
            lw.update(test_metric.get_info(), epoch, 'Test')

            # SAVE
            is_best = test_metric.RMSELIs.avg < best_RMSE
            best_RMSE = min(test_metric.RMSELIs.avg, best_RMSE)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_RMSE': best_RMSE,
                'scheduler': scheduler.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            is_best,
            epoch,
            folder=args.save_path)
    lw.done()
    # Flush and release the TensorBoard event files.
    writer.close()
Пример #30
0
def _build_train_dataset(args, train_file, n_features):
    """Return the training dataset for the mode selected by ``args``.

    ``pixor_fusion`` + ``e2e`` -> ``KittiDataset_Fusion_stereo``;
    ``pixor_fusion`` only      -> ``KittiDataset_Fusion``;
    otherwise                  -> ``KittiDataset``.
    """
    if args.pixor_fusion and args.e2e:
        # NOTE(review): this branch reads the ``eval_*_dir`` paths (and passes
        # no ``n_features``), exactly as the original code did -- confirm the
        # training split really lives under the eval directories.
        return KittiDataset_Fusion_stereo(
            txt_file=train_file,
            flip_rate=args.flip_rate,
            lidar_dir=args.eval_lidar_dir,
            label_dir=args.eval_label_dir,
            calib_dir=args.eval_calib_dir,
            image_dir=args.eval_image_dir,
            root_dir=args.root_dir,
            only_feature=args.no_cal_loss,
            split=args.split,
            image_downscale=args.image_downscale,
            crop_height=args.crop_height,
            random_shift_scale=args.random_shift_scale)
    if args.pixor_fusion:
        return KittiDataset_Fusion(
            txt_file=train_file,
            flip_rate=args.flip_rate,
            lidar_dir=args.train_lidar_dir,
            label_dir=args.train_label_dir,
            calib_dir=args.train_calib_dir,
            n_features=n_features,
            random_shift_scale=args.random_shift_scale,
            root_dir=args.root_dir,
            image_downscale=args.image_downscale)
    return KittiDataset(txt_file=train_file,
                        flip_rate=args.flip_rate,
                        lidar_dir=args.train_lidar_dir,
                        label_dir=args.train_label_dir,
                        calib_dir=args.train_calib_dir,
                        image_dir=args.train_image_dir,
                        n_features=n_features,
                        random_shift_scale=args.random_shift_scale,
                        root_dir=args.root_dir)


def _pop_finished_official_evals(savepath, pending, logger):
    """Drop ``(epoch, iou)`` entries of *pending* whose official eval is done.

    An entry counts as finished once
    ``<savepath>/predicted_label_<epoch>/outputs_<iou:02d>.txt`` exists and
    contains a ``car_detection_ground AP`` line; the last three numbers on
    that line are logged. *pending* is modified in place.
    """
    for epoch, iou in list(pending):
        result_file = osp.join(savepath, 'predicted_label_{}'.format(epoch),
                               'outputs_{:02d}.txt'.format(iou))
        if not osp.exists(result_file):
            continue
        with open(result_file, 'r') as fh:
            for line in fh:
                if line.startswith('car_detection_ground AP'):
                    results = [
                        float(num)
                        for num in line.strip('\n').split(' ')[-3:]
                    ]
                    logger.info(
                        "Official evaluation epoch {} iou {:02d}: "
                        "car_detection_ground AP {}".format(
                            epoch, iou, results))
                    pending.remove((epoch, iou))
                    # Stop at the first match; a second matching line would
                    # otherwise make ``remove`` raise ValueError.
                    break


def train(args):
    """Train a PIXOR detector, optionally image-fused and trained end-to-end
    with a PSMNet stereo depth network.

    Modes (selected by ``args``):
      * ``pixor_fusion`` and ``e2e``: depth is predicted from stereo pairs by
        ``depth_model``, converted to a point cloud and voxelized on the fly,
        and both networks are optimized on ``depth_loss + 0.1 * det_loss``.
      * ``pixor_fusion`` only: PIXOR-fusion on the dataset's ``X`` inputs.
      * neither: plain PIXOR on the dataset's ``X`` inputs.

    Checkpoints (both networks, optimizers and schedulers) are written to
    ``<saverootpath>/<run_name>/checkpoint_<epoch>.pth.tar`` every
    ``args.save_every`` epochs, and ``checkpoint.pth.tar`` is symlinked to the
    latest one. With ``args.resume`` training restarts from
    ``<saverootpath>/<run_name>/<args.checkpoint>`` when that file exists.

    Raises:
        RuntimeError: if no CUDA device is available.
        NotImplementedError: if ``args.opt_method`` is not ``'RMSprop'``.
    """
    use_gpu = torch.cuda.is_available()
    num_gpu = list(range(torch.cuda.device_count()))
    # Explicit check rather than ``assert`` so it is not stripped by -O.
    if not use_gpu:
        raise RuntimeError("Please use gpus.")

    logger = get_logger(name=args.shortname)
    display_args(args, logger)

    # create dir for saving
    args.saverootpath = osp.abspath(args.saverootpath)
    savepath = osp.join(args.saverootpath, args.run_name)
    os.makedirs(savepath, exist_ok=True)

    train_file = os.path.join(args.image_sets,
                              "{}.txt".format(args.train_dataset))
    n_features = 35 if args.no_reflex else 36
    # End-to-end training only exists on top of the fusion model.
    e2e = bool(args.pixor_fusion and args.e2e)

    train_data = _build_train_dataset(args, train_file, n_features)
    train_loader = DataLoader(train_data,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=8)

    eval_data, eval_loader = get_eval_dataset(args)

    if args.pixor_fusion:
        pixor = PixorNet_Fusion(n_features,
                                groupnorm=args.groupnorm,
                                resnet_type=args.resnet_type,
                                image_downscale=args.image_downscale,
                                resnet_chls=args.resnet_chls)
    else:
        pixor = PixorNet(n_features, groupnorm=args.groupnorm)

    pixor = pixor.cuda()
    pixor = nn.DataParallel(pixor, device_ids=num_gpu)

    class_criterion = nn.BCELoss(reduction='none')
    reg_criterion = nn.SmoothL1Loss(reduction='none')

    if args.opt_method != 'RMSprop':
        raise NotImplementedError()
    optimizer = optim.RMSprop(pixor.parameters(),
                              lr=args.lr,
                              momentum=args.momentum,
                              weight_decay=args.weight_decay)

    depth_model = PSMNet(maxdepth=80, maxdisp=192, down=args.depth_down)
    depth_model = nn.DataParallel(depth_model).cuda()
    depth_optimizer = optim.Adam(depth_model.parameters(),
                                 lr=args.depth_lr,
                                 betas=(0.9, 0.999))
    grid_3D_extended = get_3D_global_grid_extended(700, 800, 35).cuda().float()

    if args.depth_pretrain:
        if os.path.isfile(args.depth_pretrain):
            logger.info("=> loading depth pretrain '{}'".format(
                args.depth_pretrain))
            checkpoint = torch.load(args.depth_pretrain)
            depth_model.load_state_dict(checkpoint['state_dict'])
            depth_optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            logger.info('[Attention]: Do not find checkpoint {}'.format(
                args.depth_pretrain))

    depth_scheduler = MultiStepLR(depth_optimizer,
                                  milestones=args.depth_lr_stepsize,
                                  gamma=args.depth_lr_gamma)

    if args.pixor_pretrain:
        if os.path.isfile(args.pixor_pretrain):
            logger.info("=> loading pixor pretrain '{}'".format(
                args.pixor_pretrain))
            checkpoint = torch.load(args.pixor_pretrain)
            pixor.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Restored optimizer LR is scaled up 10x before the scheduler is
            # built on top of it.
            optimizer.param_groups[0]['lr'] *= 10
        else:
            logger.info('[Attention]: Do not find checkpoint {}'.format(
                args.pixor_pretrain))

    scheduler = lr_scheduler.MultiStepLR(optimizer,
                                         milestones=args.lr_milestones,
                                         gamma=args.gamma)

    start_epoch = 0
    if args.resume:
        logger.info("Resuming...")
        checkpoint_path = osp.join(savepath, args.checkpoint)
        if os.path.isfile(checkpoint_path):
            logger.info("Loading checkpoint '{}'".format(checkpoint_path))
            checkpoint = torch.load(checkpoint_path)
            pixor.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            scheduler.load_state_dict(checkpoint['scheduler'])
            depth_model.load_state_dict(checkpoint['depth_state_dict'])
            depth_optimizer.load_state_dict(checkpoint['depth_optimizer'])
            depth_scheduler.load_state_dict(checkpoint['depth_scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            logger.info(
                "Resumed successfully from epoch {}.".format(start_epoch))
        else:
            logger.warning("Model {} not found. "
                           "Train from scratch".format(checkpoint_path))

    class_criterion = class_criterion.cuda()
    reg_criterion = reg_criterion.cuda()

    processes = []
    # (epoch, iou) pairs whose official evaluation output is still pending.
    last_eval_epoches = []
    for epoch in range(start_epoch, args.epochs):
        pixor.train()
        depth_model.train()
        ts = time.time()
        logger.info("Start epoch {}, depth lr {:.6f} pixor lr {:.7f}".format(
            epoch, depth_optimizer.param_groups[0]['lr'],
            optimizer.param_groups[0]['lr']))

        avg_class_loss = AverageMeter()
        avg_reg_loss = AverageMeter()
        avg_total_loss = AverageMeter()

        train_metric = utils_func.Metric()

        for iteration, batch in enumerate(train_loader):
            if args.pixor_fusion:
                if e2e:
                    imgL = batch['imgL'].cuda()
                    imgR = batch['imgR'].cuda()
                    f = batch['f']
                    depth_map = batch['depth_map'].cuda()
                    idxx = batch['idx']
                    h_shift = batch['h_shift']
                    ori_shape = batch['ori_shape']
                    a_shift = batch['a_shift']
                    flip = batch['flip']
                else:
                    inputs = batch['X'].cuda()
                images = batch['image'].cuda()
                img_index = batch['img_index'].cuda()
                bev_index = batch['bev_index'].cuda()
            else:
                inputs = batch['X'].cuda()
            class_labels = batch['cl'].cuda()
            reg_labels = batch['rl'].cuda()

            if e2e:
                # Predict depth, then build the BEV voxel input per sample.
                depth_loss, depth_map = forward_depth_model(
                    imgL, imgR, depth_map, f, train_metric, depth_model)
                voxels = []
                for i in range(depth_map.shape[0]):
                    # NOTE(review): assumes ``train_data.dataset`` exposes
                    # get_calibration for KittiDataset_Fusion_stereo.
                    calib = utils_func.torchCalib(
                        train_data.dataset.get_calibration(idxx[i]),
                        h_shift[i])
                    H, W = ori_shape[0][i], ori_shape[1][i]
                    # Keep the bottom-left HxW region (presumably undoing
                    # padding/cropping of the depth map -- confirm).
                    depth = depth_map[i][-H:, :W]
                    ptc = depth_to_pcl(calib, depth, max_high=1.)
                    ptc = calib.lidar_to_rect(ptc[:, 0:3])

                    # Apply this sample's y-axis rotation augmentation, if any.
                    if torch.abs(a_shift[i]).item() > 1e-6:
                        roty = utils_func.roty_pth(a_shift[i]).cuda()
                        ptc = torch.mm(ptc, roty.t())
                    voxel = gen_feature_diffused_tensor(
                        ptc,
                        700,
                        800,
                        grid_3D_extended,
                        diffused=args.diffused)

                    if flip[i] > 0:
                        voxel = torch.flip(voxel, [2])

                    voxels.append(voxel)
                inputs = torch.stack(voxels)

            if args.pixor_fusion:
                class_outs, reg_outs = pixor(inputs, images, img_index,
                                             bev_index)
            else:
                class_outs, reg_outs = pixor(inputs)
            class_outs = class_outs.squeeze(1)
            class_loss, reg_loss, loss = compute_loss(
                epoch, class_outs, reg_outs, class_labels, reg_labels,
                class_criterion, reg_criterion, args)
            avg_class_loss.update(class_loss.item())
            # compute_loss may return reg_loss as a plain int.
            avg_reg_loss.update(
                reg_loss if isinstance(reg_loss, int) else reg_loss.item())
            avg_total_loss.update(loss.item())

            optimizer.zero_grad()
            if e2e:
                depth_optimizer.zero_grad()
                # Joint objective; ``depth_loss`` only exists in e2e mode.
                loss = depth_loss + 0.1 * loss
            loss.backward()
            optimizer.step()
            if e2e:
                depth_optimizer.step()

            if iteration % args.logevery == 0:
                logger.info("epoch {:d}, iter {:d}, class_loss: {:.5f},"
                            " reg_loss: {:.5f}, loss: {:.5f}".format(
                                epoch, iteration, avg_class_loss.avg,
                                avg_reg_loss.avg, avg_total_loss.avg))

                logger.info(train_metric.print(epoch, iteration))

        # Step the LR schedulers after this epoch's optimizer updates
        # (PyTorch >= 1.1 ordering); stepping them first skips the first LR
        # value and fires every milestone one epoch early.
        scheduler.step()
        depth_scheduler.step()

        logger.info("Finish epoch {}, time elapsed {:.3f} s".format(
            epoch,
            time.time() - ts))

        if epoch % args.eval_every_epoch == 0 and epoch >= args.start_eval:
            logger.info("Evaluation begins at epoch {}".format(epoch))
            evaluate(eval_data,
                     eval_loader,
                     pixor,
                     depth_model,
                     args.batch_size,
                     gpu=use_gpu,
                     logger=logger,
                     args=args,
                     epoch=epoch,
                     processes=processes,
                     grid_3D_extended=grid_3D_extended)
            if args.run_official_evaluate:
                last_eval_epoches.append((epoch, 7))
                last_eval_epoches.append((epoch, 5))

        _pop_finished_official_evals(savepath, last_eval_epoches, logger)

        if epoch % args.save_every == 0:
            saveto = osp.join(savepath, "checkpoint_{}.pth.tar".format(epoch))
            torch.save(
                {
                    'state_dict': pixor.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'depth_state_dict': depth_model.state_dict(),
                    'depth_optimizer': depth_optimizer.state_dict(),
                    'depth_scheduler': depth_scheduler.state_dict(),
                    'epoch': epoch
                }, saveto)
            logger.info("model saved to {}".format(saveto))
            symlink_force(saveto, osp.join(savepath, "checkpoint.pth.tar"))

    # Wait for background evaluation processes before the final result check.
    for p in processes:
        if p.wait() != 0:
            logger.warning("There was an error")

    _pop_finished_official_evals(savepath, last_eval_epoches, logger)