Example #1
def main():
    para = params_setup()
    logging_config_setup(para)

    logging.info('Creating graph')
    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        logging.info('Loading weights')
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            if para.mode == 'train':
                logging.info('Started training')
                train(para, sess, model, data_generator)
                if para.save_final_model_path != '':
                    save_weights(sess, model, para.save_final_model_path)
            elif para.mode == 'validation':
                logging.info('Started validation')
                test(para, sess, model, data_generator)
            elif para.mode == 'test':
                logging.info('Started testing')
                test(para, sess, model, data_generator)
            elif para.mode == 'predict':
                logging.info('Predicting')
                predict(para, sess, model, data_generator,
                        './data/solar-energy3/solar_predict.txt', para.samples)

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
Example #2
def main():
    controller = Controller()

    kwargs = {"num_workers": 1, "pin_memory": True} if FLAGS.CUDA else {}
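    # both loaders below reuse these options; pinned host memory speeds up host-to-GPU copies when CUDA is enabled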
    train_loader = DataLoader(
        datasets.Omniglot(
            "data",
            train=True,
            download=True,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # consider adding transforms.Normalize(mean, std) and other transforms for better performance
            ])),
        batch_size=FLAGS.BATCH_SIZE,
        shuffle=True,
        **kwargs)

    test_loader = DataLoader(
        datasets.Omniglot(
            "data",
            train=False,
            transform=transforms.Compose([
                transforms.ToTensor(),
                # consider adding transforms.Normalize(mean, std) and other transforms for better performance
            ])),
        batch_size=FLAGS.BATCH_SIZE,
        shuffle=True,
        **kwargs)

    if FLAGS.TRAIN:
        train(controller, train_loader)
    else:
        test(controller, test_loader)
Example #3
def run(args):
    if args.step1:  # training VAE
        step1(args)
    elif args.step2:  # training sentiment classifier
        step2(args)
    elif args.step3:  # training transfer network
        step3(args)
    elif args.chat:
        chat(args)
    elif args.test:  # test
        test(args)
Example #4
def validate(model, val_data_loader, writer, curr_iter, config, transform_data_fn):
  v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config, transform_data_fn)
  writer.add_scalar('validation/mIoU', v_mIoU, curr_iter)
  writer.add_scalar('validation/loss', v_loss, curr_iter)
  writer.add_scalar('validation/precision_at_1', v_score, curr_iter)

  return v_mIoU
Example #5
def validate(model,
             val_data_loader,
             writer,
             curr_iter,
             config,
             transform_data_fn=None):
    v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)
    return v_mIoU
Example #6
def main():
    para = params_setup()
    logging_config_setup(para)

    print("Creating graph...")
    graph, model, data_generator = create_graph(para)
    print("Done creating graph.")

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        print("Loading weights...")
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        # PRINT NAMES OF TENSORS THAT ARE ALPHAS
        # example name: "model/rnn/cond/rnn/multi_rnn_cell/cell_0/cell_0/temporal_pattern_attention_cell_wrapper/attention/Sigmoid:0"
        # for item in [n.name for n in tf.get_default_graph().as_graph_def().node
        #              if (n.name.find("temporal_pattern_attention_cell_wrapper/attention")!=-1 and
        #                  n.name.find("Sigmoid")!=-1)]:
        #     print(item)

        # Print names of ops
        # for op in tf.get_default_graph().get_operations():
        #     if(op.name.find("ben_multiply")!=-1):
        #         print(str(op.name))

        # PRINT REG KERNEL AND BIAS
        # reg_weights = [v for v in tf.global_variables() if v.name == "model/dense_2/kernel:0"][0]
        # reg_bias = [v for v in tf.global_variables() if v.name == "model/dense_2/bias:0"][0]
        # print("Reg Weights:", sess.run(reg_weights))
        # print("Reg Bias:", sess.run(reg_bias) * data_generator.scale[0])

        try:
            if para.mode == 'train':
                train(para, sess, model, data_generator)
            elif para.mode == 'test':
                print("Evaluating model...")
                test(para, sess, model, data_generator)

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
Example #7
def main():
    para = params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            if para.mode == 'train':
                train(para, sess, model, data_generator)
            elif para.mode == 'test':
                test(para, sess, model, data_generator)

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
Example #8
def validate(pipeline_model, data_loader, config, writer, curr_iter, best_val,
             best_val_iter, optimizer, epoch):
    val_dict = test(pipeline_model, data_loader, config)
    update_writer(writer, val_dict, curr_iter, 'validation')
    curr_val = pipeline_model.get_metric(val_dict)

    if curr_val > best_val:
        best_val = curr_val
        best_val_iter = curr_iter
        checkpoint(pipeline_model, optimizer, epoch, curr_iter, config,
                   best_val, best_val_iter, 'best_val')
    logging.info(
        f'Current best {pipeline_model.TARGET_METRIC}: {best_val:.3f} at iter {best_val_iter}'
    )

    # Recover back
    pipeline_model.train()

    return best_val, best_val_iter
Example #9
def main():
    para = params_setup()
    logging_config_setup(para)

    graph, model, data_generator = create_graph(para)

    with tf.Session(config=config_setup(), graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(para, sess, model)
        print_num_of_trainable_parameters()

        try:
            if para.mode == 'train':
                train(para, sess, model, data_generator)
            elif para.mode == 'test':
                obs, predicted = test(para, sess, model, data_generator)
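                # undo the data generator's scaling so the error is reported in the original units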
                obs = obs * data_generator.scale + data_generator.min_value
                predicted = predicted * data_generator.scale + data_generator.min_value
                print("MSE: ", mean_squared_error(obs[:, 0], predicted[:, 0]))
                idx = pd.date_range(start='2016-10-16',
                                    end='2018-11-04',
                                    freq='W')
                obs_df = pd.DataFrame(data=obs[:, 0],
                                      columns=['Observed'],
                                      index=idx)
                pred_df = pd.DataFrame(data=predicted[:, 0],
                                       columns=['Predicted'],
                                       index=idx)
                df = pd.concat([obs_df, pred_df], axis=1)
                df.plot()
                plt.show()

        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
Example #10
def train(model, data_loader, val_data_loader, config, transform_data_fn=None):

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
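    # voxels labeled with config.ignore_label contribute nothing to the loss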
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # we skip attention maps because the shape won't match because voxel number is different
            # e.g. copying a param with shape (23385, 8, 4) to (43529, 8, 4)
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # handle those attn maps we don't load from saved dict
            for k in model.state_dict().keys():
                if k in d.keys(): continue
                d[k] = model.state_dict()[k]
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
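    # one iterator is reused across epochs; the train loader is created with repeat=True by its callers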
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10
    elif config.dataset == "S3DIS":
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        num_class = 20
        val_freq_ = config.val_freq

    while is_training:
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)
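        # per-class hit counts and IoU denominators, accumulated for the epoch-level train IoU logged below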

        for iteration in range(len(data_loader) // config.iter_size):
            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                # if curr_iter >= max(config.max_iter, config.epochs * (len(data_loader) // config.iter_size)):
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # validate more frequently in the final third of training

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                pointcloud = None

                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation, _ = next(
                        data_iter)
                else:
                    coords, input, target, _, _, _ = next(
                        data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks it is important to make the model invariant to even/odd coordinates
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)
                sinput = SparseTensor(input, coords, device=device)
                starget = SparseTensor(
                    target.unsqueeze(-1).float(),
                    coordinate_map_key=sinput.coordinate_map_key,
                    coordinate_manager=sinput.coordinate_manager,
                    device=device
                )  # must share the same coord-manager to align for sinput

                data_time += data_timer.toc(False)
                # model.initialize_coords(*init_args)

                # d = {}
                # d['c'] = sinput.C
                # d['l'] = starget.F
                # torch.save('./plot/test-label.pth')
                # import ipdb; ipdb.set_trace()

                # Set up profiler
                # memory_profiler = CUDAMemoryProfiler(
                # [model, criterion],
                # filename="cuda_memory.profile"
                # )
                # sys.settrace(memory_profiler)
                # threading.settrace(memory_profiler)

                # with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=True) as prof0:
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    # label-aux, feed it in as additional reg
                    soutput = model(
                        sinput, iter_=curr_iter / config.max_iter, aux=starget
                    )  # feed in the progress of training for annealing inside the model

                # The output of the network is not sorted
                target = target.view(-1).long().to(device)
                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                cur_loss = torch.tensor([0.], device=device)  # stays zero when no extra reg term applies
                if hasattr(model, 'block1'):

                    if hasattr(model.block1[0], 'vq_loss'):
                        if model.block1[0].vq_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].vq_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info(
                                'Cur Loss: {}, Cur vq_loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.block1[0], 'diverse_loss'):
                        if model.block1[0].diverse_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].diverse_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info(
                                'Cur Loss: {}, Cur diverse_loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.block1[0], 'label_reg'):
                        if model.block1[0].label_reg is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].label_reg  # m is the nn.Sequential obj, m[0] is the TRBlock
                            # logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss))
                            loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

                # soutput = model(sinput)

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
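                # SAM: first_step perturbs the weights, so a second forward/backward at the perturbed point is needed before second_step applies the actual update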
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                # linear warmup: ramp the lr up to config.lr over the first lr_warmup iterations
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                else:
                    for g in optimizer.param_groups:
                        g['lr'] = config.lr * (iteration + 1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir.split('/')[-2], epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)
                # print(debug_str)
                logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # Save current status; save before val to prevent occasional mem overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))
                # print("Current best mIoU: {:.3f} at iter {}".format(best_val_miou, best_val_iter))

                # Recover back
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
Example #11
    # load model from the last checkpoint
    print("resuming from last checkpoint...")
    assert os.path.isfile('checkpoint/checkpoint.pth.tar'), "Error: no checkpoint found"
    checkpoint = torch.load("./checkpoint/checkpoint.pth.tar")
    start_epoch = checkpoint["epoch"]
    net = checkpoint["net"]
    best_acc = checkpoint["best_acc"]
    optimizer = checkpoint["optimizer"]
    print("Loaded model from {}, epoch {}".format("./checkpoint/checkpoint.pth.tar", checkpoint["epoch"]))
else:
    print("Building model")
    net = lenet.lenet()
    #net = resnet.ResNet18()

if use_cuda:
    net.cuda()
    net = torch.nn.DataParallel(net, device_ids=range(torch.cuda.device_count()))
    cudnn.benchmark = True

if __name__ == "__main__":

    print("USE CUDA : ", use_cuda)
    for epoch in range(start_epoch, start_epoch+2):
        lr = utils.lr_multiplier(epoch)
        train(epoch=epoch, trainloader=trainloader, net=net, use_cuda=use_cuda, learning_rate=lr)
        test(epoch, testloader=testloader, net=net, use_cuda=use_cuda, learning_rate=lr)

    writer.export_scalars_to_json("./all_scalars.json")
    writer.close()
Example #12
def main_worker(gpu, ngpus_per_node, config):
    config.gpu = gpu

    #if config.is_cuda and not torch.cuda.is_available():
    #  raise Exception("No GPU found")
    if config.gpu is not None:
        print("Use GPU: {} for training".format(config.gpu))
    device = get_torch_device(config.is_cuda)

    if config.distributed:
        if config.dist_url == "env://" and config.rank == -1:
            config.rank = int(os.environ["RANK"])
        if config.multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            config.rank = config.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=config.dist_backend,
                                init_method=config.dist_url,
                                world_size=config.world_size,
                                rank=config.rank)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    # return_transformation and evaluate_original_pointcloud must be enabled together
    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader, train_sampler = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            num_workers=config.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        val_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_val_workers,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader, val_sampler = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_workers,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                init_model_from_weights(model, state, freeze_bb=False)

    if config.distributed:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if config.gpu is not None:
            torch.cuda.set_device(config.gpu)
            model.cuda(config.gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            config.batch_size = int(config.batch_size / ngpus_per_node)
            config.num_workers = int(
                (config.num_workers + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[config.gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)

    if config.is_train:
        train(model,
              train_data_loader,
              val_data_loader,
              config,
              train_sampler=train_sampler,
              ngpus_per_node=ngpus_per_node)
    else:
        test(model, test_data_loader, config)
Example #13
if __name__ == "__main__":
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    PARA = params_setup()
    create_model_dir(PARA)
    logging_config_setup(PARA)
    print_parameters(PARA)

    GRAPH, MODEL = create_graph(PARA)

    with tf.Session(config=config_setup(), graph=GRAPH) as sess:
        sess.run(tf.global_variables_initializer())
        load_weights(PARA, sess, MODEL)

        COORD = tf.train.Coordinator()
        THREADS = tf.train.start_queue_runners(sess=sess, coord=COORD)
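        # the queue-runner threads feed the input pipeline; the coordinator stops them cleanly in the finally block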
        try:
            if PARA.mode == 'pretrain':
                pretrain(PARA, sess, MODEL)
            elif PARA.mode == 'rl':
                policy_gradient(PARA, sess, MODEL)
            elif PARA.mode == 'test':
                test(PARA, sess, MODEL)
        except KeyboardInterrupt:
            print('KeyboardInterrupt')
        finally:
            print('Stop')
            COORD.request_stop()
            COORD.join(THREADS)
Example #14
    args = parser.parse_args()

    if args.work is None:
        print('must specify work')

    elif args.work == 'preprocess' or args.work == 'pre':
        from lib.preprocess import do_preprocess
        if args.file is not None:
            do_preprocess(args.file[0], args.file[1])
        else:
            do_preprocess(test_exist=True)
        os.system('python3 main.py -w train -e 5 -b 256')

    elif args.work == 'test':
        from lib.test import test
        test(args.file[0], args.file[1])

    elif args.work == 'train':
        from lib.train import train
        history = train(model_path=args.model,
                        epochs=args.epochs,
                        batch_size=args.batch_size,
                        seed=args.random_seed)
        if args.plot:
            from lib.plot import plot
            plot(history)

    elif args.work == 'semi':
        from lib.semiParse import semiParse
        semiParse()
Example #15
def main(config, init_distributed=False):

    if not torch.cuda.is_available():
        raise Exception('No GPUs FOUND.')

    # setup initial seed
    torch.cuda.set_device(config.device_id)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)

    device = config.device_id
    distributed = config.distributed_world_size > 1

    if init_distributed:
        config.distributed_rank = distributed_utils.distributed_init(config)

    setup_logging(config)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            num_workers=config.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_val_workers,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)

        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS

    else:

        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.num_workers,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)

        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()

    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        # state = torch.load(config.weights)
        state = torch.load(
            config.weights,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
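        # the checkpoint is deserialized onto the CPU; parameters move to the GPU when model.cuda() is called below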

        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    model = model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model,
            device_ids=[device],
            output_device=device,
            broadcast_buffers=False,
            bucket_cap_mb=config.bucket_cap_mb)

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
Example #16
def train_worker(gpu,
                 num_devices,
                 NetClass,
                 data_loader,
                 val_data_loader,
                 config,
                 transform_data_fn=None):
    if gpu is not None:
        print("Use GPU: {} for training".format(gpu))
        rank = gpu
    addr = 23491  # port used for the TCP init_method below
    dist.init_process_group(backend="nccl",
                            init_method="tcp://127.0.0.1:{}".format(addr),
                            world_size=num_devices,
                            rank=rank)

    # replace with DistributedSampler
    if config.multiprocess:
        from lib.dataloader_dist import InfSampler
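        # rebuild the DataLoader on the same dataset, swapping in InfSampler for multi-process training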
        sampler = InfSampler(data_loader.dataset)
        data_loader = DataLoader(dataset=data_loader.dataset,
                                 num_workers=data_loader.num_workers,
                                 batch_size=data_loader.batch_size,
                                 collate_fn=data_loader.collate_fn,
                                 worker_init_fn=data_loader.worker_init_fn,
                                 sampler=sampler)

    if data_loader.dataset.NUM_IN_CHANNEL is not None:
        num_in_channel = data_loader.dataset.NUM_IN_CHANNEL
    else:
        num_in_channel = 3
    num_labels = data_loader.dataset.NUM_LABELS

    # load model
    if config.pure_point:
        model = NetClass(num_class=config.num_labels,
                         N=config.num_points,
                         normal_channel=config.num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    if config.weights == 'modelzoo':
        model.preload_modelzoo()
    elif config.weights.lower() != 'none':
        state = torch.load(config.weights)
        # delete the keys containing the attn since it raises size mismatch
        d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=False)

    torch.cuda.set_device(gpu)
    model.cuda(gpu)
    # use model with DDP
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=False)
    # Synchronized batch norm
    model = ME.MinkowskiSyncBatchNorm.convert_sync_batchnorm(model)

    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    regs, losses, scores = AverageMeter(), AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    if rank == 0:
        setup_logger(config)
        logging.info('===> Start training')

    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    if config.resume:
        # Test loaded ckpt first
        v_loss, v_score, v_mAP, v_mIoU = test(model, val_data_loader, config)

        checkpoint_fn = config.resume + '/weights.pth'
        if osp.isfile(checkpoint_fn):
            logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
            state = torch.load(checkpoint_fn)
            curr_iter = state['iteration'] + 1
            epoch = state['epoch']
            # we skip attention maps because the shape won't match because voxel number is different
            # e.g. copying a param with shape (23385, 8, 4) to (43529, 8, 4)
            d = {
                k: v
                for k, v in state['state_dict'].items() if 'map' not in k
            }
            # handle those attn maps we don't load from saved dict
            for k in model.state_dict().keys():
                if k in d.keys(): continue
                d[k] = model.state_dict()[k]
            model.load_state_dict(d)
            if config.resume_optimizer:
                scheduler = initialize_scheduler(optimizer,
                                                 config,
                                                 last_step=curr_iter)
                optimizer.load_state_dict(state['optimizer'])
            if 'best_val' in state:
                best_val_miou = state['best_val']
                best_val_iter = state['best_val_iter']
            logging.info("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_fn, state['epoch']))
        else:
            raise ValueError(
                "=> no checkpoint found at '{}'".format(checkpoint_fn))

    data_iter = data_loader.__iter__()
    device = gpu  # in multi-process training the per-process gpu id is used as the device
    if config.dataset == "SemanticKITTI":
        num_class = 19
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 10  # the original value is kept in val_freq_
    elif config.dataset == 'S3DIS':
        num_class = 13
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
    elif config.dataset == "Nuscenes":
        num_class = 16
        config.normalize_color = False
        config.xyz_input = False
        val_freq_ = config.val_freq
        config.val_freq = config.val_freq * 50
    else:
        val_freq_ = config.val_freq
        num_class = 20

    while is_training:

        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        for iteration in range(len(data_loader) // config.iter_size):

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            if curr_iter >= config.max_iter:
                # if curr_iter >= max(config.max_iter, config.epochs * (len(data_loader) // config.iter_size)):
                is_training = False
                break
            elif curr_iter >= config.max_iter * (2 / 3):
                config.val_freq = val_freq_ * 2  # validate more frequently in the final third of training

            for sub_iter in range(config.iter_size):

                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = next(
                        data_iter)
                else:
                    coords, input, target, _, _ = next(
                        data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks it is important to make the model invariant to even/odd coordinates
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / input[:, :3].max() - 0.5
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)
                # print(device)

                sinput = SparseTensor(input, coords, device=device)

                # d = {}
                # d['coord'] = sinput.C
                # d['feat'] = sinput.F
                # torch.save(d, 'voxel.pth')
                # import ipdb; ipdb.set_trace()

                data_time += data_timer.toc(False)
                # model.initialize_coords(*init_args)
                if aux is not None:
                    soutput = model(sinput, aux)
                elif config.enable_point_branch:
                    soutput = model(sinput,
                                    iter_=curr_iter / config.max_iter,
                                    enable_point_branch=True)
                else:
                    soutput = model(
                        sinput, iter_=curr_iter / config.max_iter
                    )  # feed in the progress of training for annealing inside the model
                    # soutput = model(sinput)
                # The output of the network is not sorted
                target = target.view(-1).long().to(device)

                loss = criterion(soutput.F, target.long())

                # ====== other loss regs =====
                cur_loss = torch.tensor([0.], device=device)
                if hasattr(model.module, 'block1'):  # the model is DDP-wrapped, so submodules live under .module
                    cur_loss = torch.tensor([0.], device=device)

                    if hasattr(model.module.block1[0], 'vq_loss'):
                        if model.module.block1[0].vq_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.module.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].vq_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info(
                                'Cur Loss: {}, Cur vq_loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.module.block1[0], 'diverse_loss'):
                        if model.module.block1[0].diverse_loss is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.module.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].diverse_loss  # m is the nn.Sequential obj, m[0] is the TRBlock
                            logging.info(
                                'Cur Loss: {}, Cur diverse_loss: {}'.format(
                                    loss, cur_loss))
                            loss += cur_loss

                    if hasattr(model.module.block1[0], 'label_reg'):
                        if model.module.block1[0].label_reg is not None:
                            cur_loss = torch.tensor([0.], device=device)
                            for n, m in model.module.named_children():
                                if 'block' in n:
                                    cur_loss += m[
                                        0].label_reg  # m is the nn.Sequential obj, m[0] is the TRBlock
                            # logging.info('Cur Loss: {}, Cur diverse _loss: {}'.format(loss, cur_loss))
                            loss += cur_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                if not config.use_sam:
                    loss.backward()
                else:
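                    # SAM needs a second backward later; skip DDP gradient sync on this first backward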
                    with model.no_sync():
                        loss.backward()

            # Update number of steps
            if not config.use_sam:
                optimizer.step()
            else:
                optimizer.first_step(zero_grad=True)
                soutput = model(sinput,
                                iter_=curr_iter / config.max_iter,
                                aux=starget)
                criterion(soutput.F, target.long()).backward()
                optimizer.second_step(zero_grad=True)

            if config.lr_warmup is None:
                scheduler.step()
            else:
                if curr_iter >= config.lr_warmup:
                    scheduler.step()
                else:
                    for g in optimizer.param_groups:
                        g['lr'] = config.lr * (iteration +
                                               1) / config.lr_warmup

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)

            regs.update(cur_loss.item(), target.size(0))
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(g['lr']) for g in optimizer.param_groups])
                IoU = ((total_correct_class) /
                       (total_iou_deno_class + 1e-6)).mean() * 100.
                debug_str = "===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tIoU {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, IoU.item(), data_time_avg.avg,
                    iter_time_avg.avg)
                if regs.avg > 0:
                    debug_str += "\n Additional Reg Loss {:.3f}".format(
                        regs.avg)

                if rank == 0:
                    logging.info(debug_str)
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                # Write logs
                losses.reset()
                scores.reset()

            # only save status on the 1st gpu
            if rank == 0:

                # Save current status; save before val to prevent occasional mem overflow
                if curr_iter % config.save_freq == 0:
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               save_inter=True)

                # Validation
                if curr_iter % config.val_freq == 0:
                    val_miou = validate(model, val_data_loader, None,
                                        curr_iter, config, transform_data_fn
                                        )  # feed in None for the SummaryWriter arg
                    if val_miou > best_val_miou:
                        best_val_miou = val_miou
                        best_val_iter = curr_iter
                        checkpoint(model,
                                   optimizer,
                                   epoch,
                                   curr_iter,
                                   config,
                                   best_val_miou,
                                   best_val_iter,
                                   "best_val",
                                   save_inter=True)
                    if rank == 0:
                        logging.info(
                            "Current best mIoU: {:.3f} at iter {}".format(
                                best_val_miou, best_val_iter))

                    # Recover back
                    model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        if rank == 0:
            logging.info('train point avg class IoU: %f' %
                         ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    if rank == 0:
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter)
        v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)

        if val_miou > best_val_miou and rank == 0:
            best_val_miou = val_miou
            best_val_iter = curr_iter
            logging.info("Final best miou: {}  at iter {} ".format(
                val_miou, curr_iter))
            checkpoint(model, optimizer, epoch, curr_iter, config,
                       best_val_miou, best_val_iter, "best_val")

            logging.info("Current best mIoU: {:.3f} at iter {}".format(
                best_val_miou, best_val_iter))
Example #17
def main():
    config = get_config()

    if config.test_config:
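        # the saved json config replaces the CLI config, keeping the CLI weights path and forcing is_train=False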
        json_config = json.load(open(config.test_config, 'r'))
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        config = edict(json_config)
    elif config.resume:
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #   torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    logging.info('===> Initializing dataloader')

    if config.is_train:
        setup_seed(2021)

        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            # threads=config.threads,
            threads=4,
            augment_data=True,
            elastic_distortion=config.train_elastic_distortion,
            # elastic_distortion=False,
            # shuffle=True,
            shuffle=False,
            # repeat=True,
            repeat=False,
            batch_size=config.batch_size,
            # batch_size=8,
            limit_numpoints=config.train_limit_numpoints)

        # dat = iter(train_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            # threads=0,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            # batch_size=config.val_batch_size,
            batch_size=8,
            limit_numpoints=False)

        # dat = iter(val_data_loader).__next__()
        # import ipdb; ipdb.set_trace()

        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            elastic_distortion=config.test_elastic_distortion,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    model = NetClass(num_in_channel, num_labels, config)
    logging.info('===> Number of trainable parameters: {}: {}'.format(
        NetClass.__name__, count_parameters(model)))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
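The lenient weight loading above relies on a load_state_with_same_shape helper that is not shown in this snippet. Below is a minimal sketch, assuming the helper simply keeps checkpoint entries whose names and shapes match the target model; the repository's actual filtering rules may differ.

import logging


def load_state_with_same_shape(model, weights):
    """Return only the checkpoint entries whose name and shape match the model (sketch)."""
    model_state = model.state_dict()
    filtered_weights = {
        k: v for k, v in weights.items()
        if k in model_state and v.size() == model_state[k].size()
    }
    logging.info('Loading %d/%d matching parameters',
                 len(filtered_weights), len(weights))
    return filtered_weights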
Beispiel #18
0
def train_distill(model,
                  data_loader,
                  val_data_loader,
                  config,
                  transform_data_fn=None):
    '''
    Distillation training: a Transformer student is trained against a
    pretrained Res16UNet18A teacher. The distillation-specific settings
    are hard-coded below.
    '''

    # distill_lambda = 1
    # distill_lambda = 0.33
    distill_lambda = 0.67

    # TWO_STAGE=True: the Transformer is first trained with an L2 loss to match the ResNet's activations, and is then fine-tuned with the normal cross-entropy objective in the second stage.
    # TWO_STAGE=False: the Transformer trains with the combined (cross-entropy + L2) loss throughout.

    TWO_STAGE = False
    STAGE_PERCENTAGE = 0.7  # fraction of iterations spent in stage 1; only used when TWO_STAGE is True

    device = get_torch_device(config.is_cuda)
    # Set up the train flag for batch normalization
    model.train()

    # Configuration
    data_timer, iter_timer = Timer(), Timer()
    data_time_avg, iter_time_avg = AverageMeter(), AverageMeter()
    losses, scores = AverageMeter(), AverageMeter()

    optimizer = initialize_optimizer(model.parameters(), config)
    scheduler = initialize_scheduler(optimizer, config)
    criterion = nn.CrossEntropyLoss(ignore_index=config.ignore_label)

    # Train the network
    logging.info('===> Start training')
    best_val_miou, best_val_iter, curr_iter, epoch, is_training = 0, 0, 1, 1, True

    # TODO: load only the sub-model weights
    # FIXME: the teacher checkpoint path below is hard-coded and only supports the current setup

    tch_model_cls = load_model('Res16UNet18A')
    tch_model = tch_model_cls(3, 20, config).to(device)

    # checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/resnet_base/weights.pth"
    checkpoint_fn = "/home/zhaotianchen/project/point-transformer/SpatioTemporalSegmentation-ScanNet/outputs/ScannetSparseVoxelizationDataset/Res16UNet18A/Res18A/weights.pth"  # voxel-size: 0.05
    assert osp.isfile(checkpoint_fn)
    logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
    state = torch.load(checkpoint_fn)
    d = {k: v for k, v in state['state_dict'].items() if 'map' not in k}
    tch_model.load_state_dict(d)
    if 'best_val' in state:
        best_val_miou = state['best_val']
        best_val_iter = state['best_val_iter']
    logging.info("=> loaded checkpoint '{}' (epoch {})".format(
        checkpoint_fn, state['epoch']))

    if config.resume:
        raise NotImplementedError
        # Test loaded ckpt first

        # checkpoint_fn = config.resume + '/weights.pth'
        # if osp.isfile(checkpoint_fn):
        # logging.info("=> loading checkpoint '{}'".format(checkpoint_fn))
        # state = torch.load(checkpoint_fn)
        # curr_iter = state['iteration'] + 1
        # epoch = state['epoch']
        # d = {k:v for k,v in state['state_dict'].items() if 'map' not in k }
        # model.load_state_dict(d)
        # if config.resume_optimizer:
        # scheduler = initialize_scheduler(optimizer, config, last_step=curr_iter)
        # optimizer.load_state_dict(state['optimizer'])
        # if 'best_val' in state:
        # best_val_miou = state['best_val']
        # best_val_iter = state['best_val_iter']
        # logging.info("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_fn, state['epoch']))
        # else:
        # raise ValueError("=> no checkpoint found at '{}'".format(checkpoint_fn))

    # test after loading the ckpt
    v_loss, v_score, v_mAP, v_mIoU = test(tch_model, val_data_loader, config)
    logging.info('Teacher model tested, val mIoU: {}'.format(v_mIoU))

    data_iter = data_loader.__iter__()
    while is_training:

        num_class = 20
        total_correct_class = torch.zeros(num_class, device=device)
        total_iou_deno_class = torch.zeros(num_class, device=device)

        total_iteration = len(data_loader) // config.iter_size
        for iteration in range(total_iteration):

            # NOTE: for single-stage distillation, the L2 loss might be too large at first,
            # so a warmup phase that skips the L2 loss was added.
            # With the threshold set to 0, the warmup is effectively disabled.
            if iteration < 0:
                use_distill = False
            else:
                use_distill = True

            # Stage 1 / Stage 2 boundary
            if TWO_STAGE:
                stage_boundary = int(total_iteration * STAGE_PERCENTAGE)

            optimizer.zero_grad()
            data_time, batch_loss = 0, 0
            iter_timer.tic()

            for sub_iter in range(config.iter_size):
                # Get training data
                data_timer.tic()
                if config.return_transformation:
                    coords, input, target, _, _, pointcloud, transformation = next(data_iter)
                else:
                    coords, input, target, _, _ = next(data_iter)  # ignore unique_map and inverse_map

                if config.use_aux:
                    assert target.shape[1] == 2
                    aux = target[:, 1]
                    target = target[:, 0]
                else:
                    aux = None

                # For some networks, making the network invariant to even, odd coords is important
                coords[:, 1:] += (torch.rand(3) * 100).type_as(coords)

                # Preprocess input
                if config.normalize_color:
                    input[:, :3] = input[:, :3] / 255. - 0.5
                    coords_norm = coords[:, 1:] / coords[:, 1:].max() - 0.5

                # cat xyz into the rgb feature (NOTE: coords_norm is only defined when config.normalize_color is set)
                if config.xyz_input:
                    input = torch.cat([coords_norm, input], dim=1)

                sinput = SparseTensor(input, coords, device=device)

                # TODO: return the outputs of both models
                # to avoid breaking the validation interface, use a get_loss helper to fetch the registered loss

                data_time += data_timer.toc(False)
                # model.initialize_coords(*init_args)
                if aux is not None:
                    raise NotImplementedError

                # flatten ground truth tensor
                target = target.view(-1).long().to(device)

                if TWO_STAGE:
                    if iteration < stage_boundary:
                        # Stage 1: train transformer on L2 loss
                        soutput, anchor = model(sinput, save_anchor=True)
                        # Make sure gradients don't flow to the teacher model
                        with torch.no_grad():
                            _, tch_anchor = tch_model(sinput, save_anchor=True)
                        loss = DistillLoss(tch_anchor, anchor)
                    else:
                        # Stage 2: finetune transformer on Cross-Entropy
                        soutput = model(sinput)
                        loss = criterion(soutput.F, target.long())
                else:
                    if use_distill:  # after warm up
                        soutput, anchor = model(sinput, save_anchor=True)
                        # with a pretrained teacher, do not let gradients flow back and update its parameters
                        with torch.no_grad():
                            tch_soutput, tch_anchor = tch_model(
                                sinput, save_anchor=True)

                    else:  # warming up
                        soutput = model(sinput)
                    # The output of the network is not sorted
                    loss = criterion(soutput.F, target.long())
                    # Add the L2 distillation loss when distillation is enabled
                    if use_distill:
                        distill_loss = DistillLoss(tch_anchor,
                                                   anchor) * distill_lambda
                        loss += distill_loss

                # Compute and accumulate gradient
                loss /= config.iter_size
                batch_loss += loss.item()
                loss.backward()

            # Update number of steps
            optimizer.step()
            scheduler.step()

            # CLEAR CACHE!
            torch.cuda.empty_cache()

            data_time_avg.update(data_time)
            iter_time_avg.update(iter_timer.toc(False))

            pred = get_prediction(data_loader.dataset, soutput.F, target)
            score = precision_at_one(pred, target, ignore_label=-1)
            losses.update(batch_loss, target.size(0))
            scores.update(score, target.size(0))

            # calc the train-iou
            for l in range(num_class):
                total_correct_class[l] += ((pred == l) & (target == l)).sum()
                total_iou_deno_class[l] += (((pred == l) & (target != -1)) |
                                            (target == l)).sum()

            if curr_iter >= config.max_iter:
                is_training = False
                break

            if curr_iter % config.stat_freq == 0 or curr_iter == 1:
                lrs = ', '.join(
                    ['{:.3e}'.format(x) for x in scheduler.get_lr()])
                debug_str = "[{}] ===> Epoch[{}]({}/{}): Loss {:.4f}\tLR: {}\t".format(
                    config.log_dir, epoch, curr_iter,
                    len(data_loader) // config.iter_size, losses.avg, lrs)
                debug_str += "Score {:.3f}\tData time: {:.4f}, Iter time: {:.4f}".format(
                    scores.avg, data_time_avg.avg, iter_time_avg.avg)
                logging.info(debug_str)
                if use_distill and not TWO_STAGE:
                    logging.info('Loss {} Distill Loss:{}'.format(
                        loss, distill_loss))
                # Reset timers
                data_time_avg.reset()
                iter_time_avg.reset()
                losses.reset()
                scores.reset()

            # Save current status; save before validation to prevent occasional memory overflow
            if curr_iter % config.save_freq == 0:
                checkpoint(model,
                           optimizer,
                           epoch,
                           curr_iter,
                           config,
                           best_val_miou,
                           best_val_iter,
                           save_inter=True)

            # Validation
            if curr_iter % config.val_freq == 0:
                val_miou = validate(model, val_data_loader, None, curr_iter,
                                    config, transform_data_fn)
                if val_miou > best_val_miou:
                    best_val_miou = val_miou
                    best_val_iter = curr_iter
                    checkpoint(model,
                               optimizer,
                               epoch,
                               curr_iter,
                               config,
                               best_val_miou,
                               best_val_iter,
                               "best_val",
                               save_inter=True)
                logging.info("Current best mIoU: {:.3f} at iter {}".format(
                    best_val_miou, best_val_iter))

                # Switch back to training mode after validation
                model.train()

            # End of iteration
            curr_iter += 1

        IoU = (total_correct_class) / (total_iou_deno_class + 1e-6)
        logging.info('train point avg class IoU: %f' % ((IoU).mean() * 100.))

        epoch += 1

    # Explicit memory cleanup
    if hasattr(data_iter, 'cleanup'):
        data_iter.cleanup()

    # Save the final model
    checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
               best_val_iter)
    v_loss, v_score, v_mAP, val_miou = test(model, val_data_loader, config)
    if val_miou > best_val_miou:
        best_val_miou = val_miou
        best_val_iter = curr_iter
        checkpoint(model, optimizer, epoch, curr_iter, config, best_val_miou,
                   best_val_iter, "best_val")
    logging.info("Current best mIoU: {:.3f} at iter {}".format(
        best_val_miou, best_val_iter))
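DistillLoss above is not defined in this snippet. The sketch below assumes it is a plain L2 (MSE) penalty between the student's and the teacher's anchor features, with the teacher detached so no gradients flow back into it; the actual implementation may weight or normalize the terms differently.

import torch.nn.functional as F


def DistillLoss(tch_anchor, anchor):
    """L2 distillation loss between teacher and student anchor features (sketch)."""
    # Support either a single tensor or a list of per-layer anchor tensors.
    if isinstance(anchor, (list, tuple)):
        losses = [F.mse_loss(a, t.detach()) for a, t in zip(anchor, tch_anchor)]
        return sum(losses) / len(losses)
    return F.mse_loss(anchor, tch_anchor.detach())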
Beispiel #19
0
#%%
create_dir(para.model_dir)
create_dir(para.output_dir)
json_path = para.model_dir + '/parameters.json'
json.dump(vars(para), open(json_path, 'w'), indent=4)

# %%
graph = tf.Graph()
# %%
graph, model, data_generator = create_graph(para)

# %%
with tf.Session(config=config_setup(), graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    load_weights(para, sess, model)
    print_num_of_trainable_parameters()
    train(para, sess, model, data_generator)
# %%
para.mode = 'test'
graph, model, data_generator = create_graph(para)
with tf.Session(config=config_setup(), graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    load_weights(para, sess, model)
    print_num_of_trainable_parameters()
    test(para, sess, model, data_generator)

# %%
pred_df = pd.read_parquet(
    os.path.join(para.output_dir, para.data_set + '_predict_output.parquet'))
pred_df.head(10)
Beispiel #20
0
def main():
    config = get_config()

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")

    # torch.set_num_threads(config.threads)
    torch.manual_seed(config.seed)
    if config.is_cuda:
        torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            threads=config.threads,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)
        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)
        dataset = train_data_loader.dataset
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        dataset = test_data_loader.dataset

    logging.info('===> Building model')
    pipeline_model = load_pipeline(config, dataset)
    logging.info(
        f'===> Number of trainable parameters: {count_parameters(pipeline_model)}'
    )

    # Load weights if specified by the parameter.
    if config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        pipeline_model.load_state_dict(
            state['state_dict'], strict=(not config.lenient_weight_loading))
    if config.pretrained_weights.lower() != 'none':
        logging.info('===> Loading pretrained weights: ' +
                     config.pretrained_weights)
        state = torch.load(config.pretrained_weights)
        pipeline_model.load_pretrained_weights(state['state_dict'])

    if config.is_train:
        train(pipeline_model, train_data_loader, val_data_loader, config)
    else:
        test(pipeline_model, test_data_loader, config)
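count_parameters is used by several of these examples but never shown. A common one-line implementation is sketched below; this is an assumption about the helper, not code taken from the repository.

def count_parameters(model):
    # Sum the element counts of all trainable parameters.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)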
Beispiel #21
0
def main(config):
    # load the configurations
    setup_logging()
    if os.path.exists('config.yaml'):
        logging.info('===> Loading existing config file')
        config = OmegaConf.load('config.yaml')
        logging.info('===> Loaded existing config file')
    logging.info(config.pretty())

    # Create Dataset and Dataloader
    if config.data.dataset == 'sunrgbd':
        from lib.datasets.sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset, MAX_NUM_OBJ
        from lib.datasets.sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig
        dataset_config = SunrgbdDatasetConfig()
        train_dataset = SunrgbdDetectionVotesDataset(
            'train',
            num_points=config.data.num_points,
            augment=True,
            use_color=config.data.use_color,
            use_height=(not config.data.no_height),
            use_v1=(not config.data.use_sunrgbd_v2),
            data_ratio=config.data.data_ratio)
        test_dataset = SunrgbdDetectionVotesDataset(
            'val',
            num_points=config.data.num_points,
            augment=False,
            use_color=config.data.use_color,
            use_height=(not config.data.no_height),
            use_v1=(not config.data.use_sunrgbd_v2))
    elif config.data.dataset == 'scannet':
        from lib.datasets.scannet.scannet_detection_dataset import ScannetDetectionDataset, MAX_NUM_OBJ
        from lib.datasets.scannet.model_util_scannet import ScannetDatasetConfig
        dataset_config = ScannetDatasetConfig()
        train_dataset = ScannetDetectionDataset(
            'train',
            num_points=config.data.num_points,
            augment=True,
            use_color=config.data.use_color,
            use_height=(not config.data.no_height),
            data_ratio=config.data.data_ratio)
        test_dataset = ScannetDetectionDataset(
            'val',
            num_points=config.data.num_points,
            augment=False,
            use_color=config.data.use_color,
            use_height=(not config.data.no_height))
    else:
        logging.info('Unknown dataset %s. Exiting...' % (config.data.dataset))
        exit(-1)

    COLLATE_FN = None
    if config.data.voxelization:
        from models.backbone.sparseconv.voxelized_dataset import VoxelizationDataset, collate_fn
        train_dataset = VoxelizationDataset(train_dataset,
                                            config.data.voxel_size)
        test_dataset = VoxelizationDataset(test_dataset,
                                           config.data.voxel_size)
        COLLATE_FN = collate_fn

    logging.info('training: {}, testing: {}'.format(len(train_dataset),
                                                    len(test_dataset)))

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=config.data.batch_size,
                                  shuffle=True,
                                  num_workers=config.data.num_workers,
                                  worker_init_fn=my_worker_init_fn,
                                  collate_fn=COLLATE_FN)

    test_dataloader = DataLoader(test_dataset,
                                 batch_size=config.data.batch_size,
                                 shuffle=True,
                                 num_workers=config.data.num_workers,
                                 worker_init_fn=my_worker_init_fn,
                                 collate_fn=COLLATE_FN)

    logging.info('train dataloader: {}, test dataloader: {}'.format(
        len(train_dataloader), len(test_dataloader)))

    # Init the model and optimizer
    MODEL = importlib.import_module('models.' +
                                    config.net.model)  # import network module
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    num_input_channel = int(
        config.data.use_color) * 3 + int(not config.data.no_height) * 1

    if config.net.model == 'boxnet':
        Detector = MODEL.BoxNet
    else:
        Detector = MODEL.VoteNet

    net = Detector(num_class=dataset_config.num_class,
                   num_heading_bin=dataset_config.num_heading_bin,
                   num_size_cluster=dataset_config.num_size_cluster,
                   mean_size_arr=dataset_config.mean_size_arr,
                   num_proposal=config.net.num_target,
                   input_feature_dim=num_input_channel,
                   vote_factor=config.net.vote_factor,
                   sampling=config.net.cluster_sampling,
                   backbone=config.net.backbone)

    if config.net.weights is not None:
        assert config.net.backbone == "sparseconv", "only support sparseconv"
        print('===> Loading weights: ' + config.net.weights)
        state = torch.load(
            config.net.weights,
            map_location=lambda s, l: default_restore_location(s, 'cpu'))
        model = net
        if config.net.is_train:
            model = net.backbone_net.net
        matched_weights = load_state_with_same_shape(model,
                                                     state['state_dict'])
        model_dict = model.state_dict()
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict)

        # from pdb import set_trace; set_trace()

    net.to(device)

    if config.net.is_train:
        train(net, train_dataloader, test_dataloader, dataset_config, config)
    else:
        test(net, test_dataloader, dataset_config, config)
Beispiel #22
0
def main():
    config = get_config()
    if config.resume:
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    device = get_torch_device(config.is_cuda)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('    {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    if config.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError(
                'This dataset does not support full pointcloud evaluation.')

    if config.evaluate_original_pointcloud:
        if not config.return_transformation:
            raise ValueError(
                'Pointcloud evaluation requires config.return_transformation=true.'
            )

    if (config.return_transformation ^ config.evaluate_original_pointcloud):
        raise ValueError(
            'Rotation evaluation requires config.evaluate_original_pointcloud=true and '
            'config.return_transformation=true.')

    logging.info('===> Initializing dataloader')
    if config.is_train:
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train_phase,
            threads=config.threads,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.batch_size,
            limit_numpoints=config.train_limit_numpoints)

        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.val_threads,
            phase=config.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.val_batch_size,
            limit_numpoints=False)
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            threads=config.threads,
            phase=config.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.test_batch_size,
            limit_numpoints=False)
        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS

    logging.info('===> Building model')
    NetClass = load_model(config.model)
    if config.wrapper_type == 'None':
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))

    logging.info(model)
    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()

    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        if config.weights_for_inner_model:
            model.model.load_state_dict(state['state_dict'])
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state['state_dict'])

    if config.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
Beispiel #23
0
def main():
    config = get_config()
    ch = logging.StreamHandler(sys.stdout)
    logging.getLogger().setLevel(logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(
        os.path.join(config.log_dir, './model.log'))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logging.basicConfig(format=os.uname()[1].split('.')[0] +
                        ' %(asctime)s %(message)s',
                        datefmt='%m/%d %H:%M:%S',
                        handlers=[ch, file_handler])

    if config.test_config:
        # When using test_config, reload and overwrite the saved config, but keep a few command-line settings
        val_bs = config.val_batch_size
        is_export = config.is_export

        json_config = json.load(open(config.test_config, 'r'))
        json_config['is_train'] = False
        json_config['weights'] = config.weights
        json_config['multiprocess'] = False
        json_config['log_dir'] = config.log_dir
        json_config['val_threads'] = config.val_threads
        json_config['submit'] = config.submit
        config = edict(json_config)

        config.val_batch_size = val_bs
        config.is_export = is_export
        config.is_train = False
        sys.path.append(config.log_dir)
        # from local_models import load_model
    else:
        '''back up files'''
        if not os.path.exists(os.path.join(config.log_dir, 'models')):
            os.mkdir(os.path.join(config.log_dir, 'models'))
        for filename in os.listdir('./models'):
            if ".py" in filename:  # donnot cp the init file since it will raise import error
                shutil.copy(os.path.join("./models", filename),
                            os.path.join(config.log_dir, 'models'))
            elif 'modules' in filename:
                # copy the modules folder as well
                if os.path.exists(
                        os.path.join(config.log_dir, 'models/modules')):
                    shutil.rmtree(
                        os.path.join(config.log_dir, 'models/modules'))
                shutil.copytree(os.path.join('./models', filename),
                                os.path.join(config.log_dir, 'models/modules'))

        shutil.copy('./main.py', config.log_dir)
        shutil.copy('./config.py', config.log_dir)
        shutil.copy('./lib/train.py', config.log_dir)
        shutil.copy('./lib/test.py', config.log_dir)

    if config.resume == 'True':
        new_iter_size = config.max_iter
        new_bs = config.batch_size
        config.resume = config.log_dir
        json_config = json.load(open(config.resume + '/config.json', 'r'))
        json_config['resume'] = config.resume
        config = edict(json_config)
        config.weights = os.path.join(
            config.log_dir, 'weights.pth')  # use the pre-trained weights
        logging.info('==== Resuming: previous max_iter {}, new max_iter {} ======'.format(
            config.max_iter, new_iter_size))
        config.max_iter = new_iter_size
        config.batch_size = new_bs
    else:
        config.resume = None

    if config.is_cuda and not torch.cuda.is_available():
        raise Exception("No GPU found")
    gpu_list = range(config.num_gpu)
    device = get_torch_device(config.is_cuda)

    # torch.set_num_threads(config.threads)
    # torch.manual_seed(config.seed)
    # if config.is_cuda:
    #       torch.cuda.manual_seed(config.seed)

    logging.info('===> Configurations')
    dconfig = vars(config)
    for k in dconfig:
        logging.info('      {}: {}'.format(k, dconfig[k]))

    DatasetClass = load_dataset(config.dataset)
    logging.info('===> Initializing dataloader')

    setup_seed(2021)
    """
    ---- Setting up train, val, test dataloaders ----
    Supported datasets:
    - ScannetSparseVoxelizationDataset
    - ScannetDataset
    - SemanticKITTI
    """

    point_scannet = False
    if config.is_train:

        if config.dataset == 'ScannetSparseVoxelizationDataset':
            point_scannet = False
            train_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                phase=config.train_phase,
                threads=config.threads,
                augment_data=True,
                elastic_distortion=config.train_elastic_distortion,
                shuffle=True,
                # shuffle=False,   # DEBUG ONLY!!!
                repeat=True,
                # repeat=False,
                batch_size=config.batch_size,
                limit_numpoints=config.train_limit_numpoints)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

        elif config.dataset == 'ScannetDataset':
            val_DatasetClass = load_dataset(
                'ScannetDatasetWholeScene_evaluation')
            point_scannet = True

            # collate_fn = t.cfl_collate_fn_factory(False) # no limit num-points
            trainset = DatasetClass(
                root=
                '/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                npoints=config.num_points,
                # split='debug',
                split='train',
                with_norm=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                dataset=trainset,
                num_workers=config.threads,
                # num_workers=0,  # for loading big pth file, should use single-thread
                batch_size=config.batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn,
                sampler=InfSampler(trainset, True))  # shuffle=True

            valset = val_DatasetClass(
                root=
                '/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir=
                '/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                # split='debug',
                split='eval',
                block_points=config.num_points,
                with_norm=False,
                delta=1.0,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=
                0,  # for loading big pth file, should use single-thread
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn)

        elif config.dataset == "SemanticKITTI":
            point_scannet = False
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    sample_stride=config.sample_stride,
                                    submit=False)
            collate_fn_factory = t.cfl_collate_fn_factory
            train_data_loader = torch.utils.data.DataLoader(
                dataset['train'],
                batch_size=config.batch_size,
                sampler=InfSampler(dataset['train'],
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=collate_fn_factory(config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == "S3DIS":
            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            # todo:
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,    # used when cylinder voxelize
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

        # Setting up num_in_channel and num_labels
        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3

        num_labels = train_data_loader.dataset.NUM_LABELS

        # it = iter(train_data_loader)
        # for _ in range(100):
        # data = it.__next__()
        # print(data)

    else:  # not config.is_train

        val_DatasetClass = load_dataset('ScannetDatasetWholeScene_evaluation')

        if config.dataset == 'ScannetSparseVoxelizationDataset':

            if config.is_export:  # when exporting, also export results for the train split
                train_data_loader = initialize_data_loader(
                    DatasetClass,
                    config,
                    phase=config.train_phase,
                    threads=config.threads,
                    augment_data=True,
                    elastic_distortion=config.
                    train_elastic_distortion,  # DEBUG: not sure about this
                    shuffle=False,
                    repeat=False,
                    batch_size=config.batch_size,
                    limit_numpoints=config.train_limit_numpoints)

                # a validation-style loader, no augmented data
                # train_data_loader = initialize_data_loader(
                # DatasetClass,
                # config,
                # threads=config.val_threads,
                # phase=config.train_phase,
                # augment_data=False,
                # elastic_distortion=config.test_elastic_distortion,
                # shuffle=False,
                # repeat=False,
                # batch_size=config.val_batch_size,
                # limit_numpoints=False)

            val_data_loader = initialize_data_loader(
                DatasetClass,
                config,
                threads=config.val_threads,
                phase=config.val_phase,
                augment_data=False,
                elastic_distortion=config.test_elastic_distortion,
                shuffle=False,
                repeat=False,
                batch_size=config.val_batch_size,
                limit_numpoints=False)

            if val_data_loader.dataset.NUM_IN_CHANNEL is not None:
                num_in_channel = val_data_loader.dataset.NUM_IN_CHANNEL
            else:
                num_in_channel = 3

            num_labels = val_data_loader.dataset.NUM_LABELS

        elif config.dataset == 'ScannetDataset':
            '''when using point-based ScanNet, use the val split instead of test'''

            point_scannet = True
            valset = val_DatasetClass(
                root=
                '/data/eva_share_users/zhaotianchen/scannet/raw/scannet_pickles',
                scene_list_dir=
                '/data/eva_share_users/zhaotianchen/scannet/raw/metadata',
                split='eval',
                block_points=config.num_points,
                delta=1.0,
                with_norm=False,
            )
            val_data_loader = torch.utils.data.DataLoader(
                dataset=valset,
                # num_workers=config.threads,
                num_workers=
                0,  # for loading big pth file, should use single-thread
                batch_size=config.val_batch_size,
                # collate_fn=collate_fn, # input points, should not have collate-fn
                worker_init_fn=_init_fn,
            )

            num_labels = val_data_loader.dataset.NUM_LABELS
            num_in_channel = 3

        elif config.dataset == "SemanticKITTI":
            dataset = SemanticKITTI(root=config.semantic_kitti_path,
                                    num_points=None,
                                    voxel_size=config.voxel_size,
                                    submit=config.submit)
            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                dataset['test'],
                batch_size=config.val_batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 4
            num_labels = 19

        elif config.dataset == 'S3DIS':
            config.xyz_input = False

            trainset = S3DIS(
                config,
                train=True,
            )
            valset = S3DIS(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(
                    config.train_limit_numpoints))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 9
            num_labels = 13
        elif config.dataset == 'Nuscenes':
            config.xyz_input = False
            trainset = Nuscenes(
                config,
                train=True,
            )
            valset = Nuscenes(
                config,
                train=False,
            )
            train_data_loader = torch.utils.data.DataLoader(
                trainset,
                batch_size=config.batch_size,
                sampler=InfSampler(trainset,
                                   shuffle=True),  # shuffle=true, repeat=true
                num_workers=config.threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))

            val_data_loader = torch.utils.data.DataLoader(  # shuffle=false, repeat=false
                valset,
                batch_size=config.batch_size,
                num_workers=config.val_threads,
                pin_memory=True,
                # collate_fn=t.collate_fn_BEV,
                collate_fn=t.cfl_collate_fn_factory(False))
            num_in_channel = 5
            num_labels = 16
        else:
            print('Dataset {} not supported'.format(config.dataset))
            raise NotImplementedError

    logging.info('===> Building model')

    # if config.model == 'PointTransformer' or config.model == 'MixedTransformer':
    if config.model == 'PointTransformer':
        config.pure_point = True

    NetClass = load_model(config.model)
    if config.pure_point:
        model = NetClass(config,
                         num_class=num_labels,
                         N=config.num_points,
                         normal_channel=num_in_channel)
    else:
        if config.model == 'MixedTransformer':
            model = NetClass(config,
                             num_class=num_labels,
                             N=config.num_points,
                             normal_channel=num_in_channel)
        elif config.model == 'MinkowskiVoxelTransformer':
            model = NetClass(config, num_in_channel, num_labels)
        elif config.model == 'MinkowskiTransformerNet':
            model = NetClass(config, num_in_channel, num_labels)
        elif "Res" in config.model:
            model = NetClass(num_in_channel, num_labels, config)
        else:
            model = NetClass(num_in_channel, num_labels, config)

    logging.info('===> Number of trainable parameters: {}: {}M'.format(
        NetClass.__name__,
        count_parameters(model) / 1e6))
    if hasattr(model, "block1"):
        if hasattr(model.block1[0], 'h'):
            h = model.block1[0].h
            vec_dim = model.block1[0].vec_dim
        else:
            h = None
            vec_dim = None
    else:
        h = None
        vec_dim = None
    # logging.info('===> Model Args:\n PLANES: {} \n LAYERS: {}\n HEADS: {}\n Vec-dim: {}\n'.format(model.PLANES, model.LAYERS, h, vec_dim))
    logging.info(model)

    # Set the number of threads
    # ME.initialize_nthreads(12, D=3)

    model = model.to(device)

    if config.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.weights)
        state = torch.load(config.weights)
        # drop checkpoint keys containing '_map' since they can raise size mismatches
        d_ = {
            k: v
            for k, v in state['state_dict'].items() if '_map' not in k
        }  # debug: the model sometimes contains 'map_qk', which is a poor name for a module, since 'map' entries are always buffers
        d = {}
        for k in d_.keys():
            if 'module.' in k:
                d[k.replace('module.', '')] = d_[k]
            else:
                d[k] = d_[k]
        # del d_

        if config.weights_for_inner_model:
            model.model.load_state_dict(d)
        else:
            if config.lenient_weight_loading:
                matched_weights = load_state_with_same_shape(
                    model, state['state_dict'])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(d, strict=True)

    if config.is_debug:
        check_data(model, train_data_loader, val_data_loader, config)
        return None
    elif config.is_train:
        if hasattr(config, 'distill') and config.distill:
            assert point_scannet is not True  # only the whole-scene pipeline is supported for now
            train_distill(model, train_data_loader, val_data_loader, config)
        elif config.multiprocess:
            if point_scannet:
                raise NotImplementedError
            else:
                train_mp(NetClass, train_data_loader, val_data_loader, config)
        else:
            if point_scannet:
                train_point(model, train_data_loader, val_data_loader, config)
            else:
                train(model, train_data_loader, val_data_loader, config)
    elif config.is_export:
        if point_scannet:
            raise NotImplementedError
        else:  # only support the whole-scene-style for now
            test(model,
                 train_data_loader,
                 config,
                 save_pred=True,
                 split='train')
            test(model, val_data_loader, config, save_pred=True, split='val')
    else:
        assert config.multiprocess == False
        # when testing for submission, create a submit directory under the current working directory
        submit_dir = os.path.join(os.getcwd(), 'submit', 'sequences')
        if config.submit and not os.path.exists(submit_dir):
            os.makedirs(submit_dir)
            print("Made submission directory: " + submit_dir)
        if point_scannet:
            test_points(model, val_data_loader, config)
        else:
            test(model, val_data_loader, config, submit_dir=submit_dir)
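The SemanticKITTI, S3DIS and Nuscenes loaders above depend on an InfSampler that is not part of this snippet. Below is a minimal sketch, assuming the sampler yields dataset indices indefinitely and reshuffles the permutation once it is exhausted; the repository's version may differ in detail.

import torch
from torch.utils.data.sampler import Sampler


class InfSampler(Sampler):
    """Sample dataset indices forever, optionally reshuffling each pass (sketch)."""

    def __init__(self, data_source, shuffle=False):
        self.data_source = data_source
        self.shuffle = shuffle
        self.reset_permutation()

    def reset_permutation(self):
        # Draw a fresh (possibly shuffled) ordering of all indices.
        n = len(self.data_source)
        perm = torch.randperm(n) if self.shuffle else torch.arange(n)
        self._perm = perm.tolist()

    def __iter__(self):
        return self

    def __next__(self):
        if len(self._perm) == 0:
            self.reset_permutation()
        return self._perm.pop()

    def __len__(self):
        return len(self.data_source)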