def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if a training run with this name already exists; if so, abort
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        print(f'A training run named "{conf.exp_name}" already exists!')
        sys.exit()

    # create directories for this run
    os.makedirs(os.path.join(conf.model_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)
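
    # hedge: torch.manual_seed seeds CUDA devices on current PyTorch, but for
    # older versions and multi-GPU runs it is safer to seed them explicitly
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(conf.seed)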

    # save config
    torch.save(conf, os.path.join(conf.model_path, conf.exp_name, 'conf.pth'))
    print('\n\n Configuration:\n', vars(conf), '\n\n')

    # create models
    encoder = models.RecursiveEncoder(conf, variational=not conf.non_variational,
                                      probabilistic=not conf.non_probabilistic)
    decoder = models.RecursiveDecoder(conf)
    models = [encoder, decoder]
    model_names = ['encoder', 'decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['encoder', 'decoder']

    # learning rate scheduler
    if conf.scheduler == 'StepLR':
        encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_opt, \
                step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
        decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_opt, \
                step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    elif conf.scheduler == 'ReduceLROnPlateau':
        encoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(encoder_opt, mode='min', \
                factor=conf.lr_decay_factor, patience=2)
        decoder_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(decoder_opt, mode='min',\
                factor=conf.lr_decay_factor, patience=2)
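    else:
        # guard added here: fail fast on unknown scheduler names instead of
        # hitting an UnboundLocalError when the schedulers are stepped later
        raise ValueError(f'unsupported scheduler: {conf.scheduler}')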

    # create training and validation datasets and data loaders
    data_features = ['uxid', 'object']

    # resolve the dataset class by name; the class must be importable from
    # this module's global namespace
    DatasetClass = globals()[conf.DatasetClass]
    print('Using dataset:', DatasetClass)
    train_dataset = DatasetClass(conf.data_path, conf.train_dataset, data_features,
                                 is_train=True, permute=(conf.permutations > 1),
                                 n_permutes=conf.permutations)
    valdt_dataset = DatasetClass(conf.data_path, conf.val_dataset, data_features,
                                 is_train=False, permute=False, n_permutes=1)

    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, \
            shuffle=True, collate_fn=lambda x: list(zip(*x)))
    valdt_dataloader = torch.utils.data.DataLoader(valdt_dataset, batch_size=conf.batch_size, \
            shuffle=True, collate_fn=lambda x: list(zip(*x)))
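    # note: the collate_fn above transposes a batch of per-sample tuples into
    # per-feature tuples, so batch[data_features.index(f)] yields all values
    # of feature f across the batch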

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration     Progress(%)       LR       BoxLoss   LeafLoss     NodeExists      SemLoss   ChildCountLoss    KLDivLoss     TotalLoss'
    
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'train'))
        valdt_writer = SummaryWriter(os.path.join(conf.log_path, conf.exp_name, 'val'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ...... ")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch, valdt_num_batch = len(train_dataloader), len(valdt_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
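        # per-epoch loss trackers; kept global, presumably so helpers such as
        # forward() can also record into them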
        global train_stats
        global valdt_stats
        
        train_stats = Statistics()
        valdt_stats = Statistics()

        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header+'\n')

        train_batches = enumerate(train_dataloader, 0)
        valdt_batches = enumerate(valdt_dataloader, 0)

        train_fraction_done, valdt_fraction_done = 0.0, 0.0
        valdt_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:

            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(
                batch=batch, data_features=data_features, encoder=encoder, decoder=decoder, device=device, conf=conf,
                is_valdt=False, step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time,
                log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer,
                lr=encoder_opt.param_groups[0]['lr'], flog=flog)

            # optimize one step
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()
            encoder_opt.step()
            decoder_opt.step()

            train_stats.add('total_loss', float(total_loss.item()), len(batch[0]))
            del total_loss

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(
                        models=models, model_names=model_names, dirname=os.path.join(conf.model_path, conf.exp_name),
                        epoch=epoch, prepend_epoch=True, optimizers=optimizers, optimizer_names=optimizer_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

        # validate over the full validation set at the end of the epoch
        with torch.no_grad():
            for valdt_batch_ind, batch in valdt_batches:
                valdt_fraction_done = (valdt_batch_ind + 1) / valdt_num_batch
                valdt_step = epoch * valdt_num_batch + valdt_batch_ind

                log_console = not conf.no_console_log and (last_valdt_console_log_step is None or \
                        valdt_step - last_valdt_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_valdt_console_log_step = valdt_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                # forward pass (including logging)
                total_loss = forward(
                    batch=batch, data_features=data_features, encoder=encoder, decoder=decoder, device=device, conf=conf,
                    is_valdt=True, step=valdt_step, epoch=epoch, batch_ind=valdt_batch_ind, num_batch=valdt_num_batch, start_time=start_time,
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=valdt_writer,
                    lr=encoder_opt.param_groups[0]['lr'], flog=flog)

                valdt_stats.add('total_loss', float(total_loss.item()), len(batch[0]))

        valid_loss = valdt_stats.mean(conf.metric)
        
        if conf.scheduler == 'StepLR':
            encoder_scheduler.step()
            decoder_scheduler.step()
        elif conf.scheduler == 'ReduceLROnPlateau':
            encoder_scheduler.step(valid_loss)
            decoder_scheduler.step(valid_loss)
    
    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write("Saving final checkpoint ...... ")
    utils.save_checkpoint(
        models=models, model_names=model_names, dirname=os.path.join(conf.model_path, conf.exp_name),
        epoch=epoch, prepend_epoch=False, optimizers=optimizers, optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
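
# The Statistics tracker used above is not included in this snippet; a minimal
# sketch consistent with its call sites (add(name, value, count) and
# mean(name)) might look like this -- an assumption, not the original class:
class Statistics:
    def __init__(self):
        self.sums = {}    # metric name -> running weighted sum
        self.counts = {}  # metric name -> total sample count

    def add(self, name, value, count):
        # accumulate a batch-level value, weighted by batch size
        self.sums[name] = self.sums.get(name, 0.0) + value * count
        self.counts[name] = self.counts.get(name, 0) + count

    def mean(self, name):
        # sample-weighted mean over all batches added so far
        return self.sums[name] / max(self.counts[name], 1)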

# Example 2

def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, delete it.
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        response = input(
            'A training run named "%s" already exists, overwrite? (y/n) ' %
            (conf.exp_name))
        if response != 'y':
            sys.exit()
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.log_path, conf.exp_name))
    if os.path.exists(os.path.join(conf.model_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.model_path, conf.exp_name))

    # create directories for this run
    os.makedirs(os.path.join(conf.model_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # create models
    encoder = models.PartEncoder(feat_len=conf.geo_feat_size,
                                 probabilistic=not conf.non_variational)
    decoder = models.PartDecoder(feat_len=conf.geo_feat_size,
                                 num_point=conf.num_point)
    models = [encoder, decoder]
    model_names = ['part_pc_encoder', 'part_pc_decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(),
                                   lr=conf.lr,
                                   weight_decay=conf.weight_decay)
    decoder_opt = torch.optim.Adam(decoder.parameters(),
                                   lr=conf.lr,
                                   weight_decay=conf.weight_decay)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['part_pc_encoder', 'part_pc_decoder']

    # learning rate scheduler
    encoder_scheduler = torch.optim.lr_scheduler.StepLR(encoder_opt, \
            step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(decoder_opt, \
            step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create training and validation datasets and data loaders
    train_dataset = PartNetGeoDataset(conf.data_path,
                                      conf.train_dataset,
                                      use_local_frame=conf.use_local_frame)
    valdt_dataset = PartNetGeoDataset(conf.data_path,
                                      conf.val_dataset,
                                      use_local_frame=conf.use_local_frame)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, \
            shuffle=True, collate_fn=utils.collate_feats)
    valdt_dataloader = torch.utils.data.DataLoader(valdt_dataset, batch_size=conf.batch_size, \
            shuffle=True, collate_fn=utils.collate_feats)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch    Dataset    Iteration    Progress(%)     LR      ReconLoss  KLDivLoss  TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(
            os.path.join(conf.log_path, conf.exp_name, 'train'))
        valdt_writer = SummaryWriter(
            os.path.join(conf.log_path, conf.exp_name, 'val'))

    # save config
    torch.save(conf, os.path.join(conf.model_path, conf.exp_name, 'conf.pth'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ...... ")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch, valdt_num_batch = len(train_dataloader), len(
        valdt_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header + '\n')

        train_batches = enumerate(train_dataloader, 0)
        valdt_batches = enumerate(valdt_dataloader, 0)

        train_fraction_done, valdt_fraction_done = 0.0, 0.0
        valdt_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(batch=batch,
                                 encoder=encoder,
                                 decoder=decoder,
                                 device=device,
                                 conf=conf,
                                 is_valdt=False,
                                 step=train_step,
                                 epoch=epoch,
                                 batch_ind=train_batch_ind,
                                 num_batch=train_num_batch,
                                 start_time=start_time,
                                 log_console=log_console,
                                 log_tb=not conf.no_tb_log,
                                 tb_writer=train_writer,
                                 lr=encoder_opt.param_groups[0]['lr'],
                                 flog=flog)

            # optimize one step (scheduler stepped after the optimizer, as
            # required by PyTorch >= 1.1)
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()
            encoder_opt.step()
            decoder_opt.step()
            encoder_scheduler.step()
            decoder_scheduler.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(models=models,
                                          model_names=model_names,
                                          dirname=os.path.join(
                                              conf.model_path, conf.exp_name),
                                          epoch=epoch,
                                          prepend_epoch=True,
                                          optimizers=optimizers,
                                          optimizer_names=optimizer_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

            # validate one batch
            while valdt_fraction_done <= train_fraction_done and valdt_batch_ind + 1 < valdt_num_batch:
                valdt_batch_ind, batch = next(valdt_batches)

                valdt_fraction_done = (valdt_batch_ind + 1) / valdt_num_batch
                valdt_step = (epoch +
                              valdt_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_valdt_console_log_step is None or \
                        valdt_step - last_valdt_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_valdt_console_log_step = valdt_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=batch,
                                 encoder=encoder,
                                 decoder=decoder,
                                 device=device,
                                 conf=conf,
                                 is_valdt=True,
                                 step=valdt_step,
                                 epoch=epoch,
                                 batch_ind=valdt_batch_ind,
                                 num_batch=valdt_num_batch,
                                 start_time=start_time,
                                 log_console=log_console,
                                 log_tb=not conf.no_tb_log,
                                 tb_writer=valdt_writer,
                                 lr=encoder_opt.param_groups[0]['lr'],
                                 flog=flog)

    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write('Saving final checkpoint ...... ')
    utils.save_checkpoint(models=models,
                          model_names=model_names,
                          dirname=os.path.join(conf.model_path, conf.exp_name),
                          epoch=epoch,
                          prepend_epoch=False,
                          optimizers=optimizers,
                          optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
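
# utils.optimizer_to_device is referenced throughout these examples but not
# shown; a common sketch (an assumption about the helper, not its verbatim
# source) moves optimizer state tensors, e.g. Adam moments, onto the device:
def optimizer_to_device(optimizer, device):
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)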

# Example 3

def train(conf):
    # create training and validation datasets and data loaders
    data_features = ['img', 'pts', 'ins_one_hot', 'box_size', 'total_parts_cnt', 'similar_parts_cnt', 'mask', 'shape_id', 'view_id']

    train_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="train", \
            max_num_mask=conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts, img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(train_dataset))
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True, \
            num_workers=conf.num_workers, drop_last=True, collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)

    val_dataset = PartNetShapeDataset(conf.category, conf.data_dir, data_features, data_split="val", \
            max_num_mask=conf.max_num_parts, max_num_similar_parts=conf.max_num_similar_parts, img_size=conf.img_size, on_kaichun_machine=conf.on_kaichun_machine)
    utils.printout(conf.flog, str(val_dataset))
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True, \
            num_workers=0, drop_last=True, collate_fn=utils.collate_feats_with_none, worker_init_fn=utils.worker_init_fn)
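    # collate_feats_with_none (not shown in this snippet) presumably behaves
    # like a transposing collate that tolerates None entries from samples
    # that failed to load -- an assumption read from the helper's name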

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf, train_dataset.get_part_count())
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    models = [network]
    model_names = ['network']

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(), lr=conf.lr, weight_decay=conf.weight_decay)
    optimizers = [network_opt]
    optimizer_names = ['network_opt']

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR    CenterLoss    QuatLoss   TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    for m in models:
        m.to(conf.device)
    for o in optimizers:
        utils.optimizer_to_device(o, conf.device)

    # start training
    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_val_console_log_step = None, None
    train_num_batch = len(train_dataloader)
    val_num_batch = len(val_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)
        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False, \
                    step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time, \
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            if total_loss is not None:
                # optimize one step (optimizer before scheduler, per PyTorch >= 1.1)
                network_opt.zero_grad()
                total_loss.backward()
                network_opt.step()
                network_lr_scheduler.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'), \
                            epoch=epoch, prepend_epoch=True, optimizers=optimizers, optimizer_names=optimizer_names)
                    utils.printout(conf.flog, 'DONE')
                    last_checkpoint_step = train_step

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind+1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                        val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                for m in models:
                    m.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True, \
                            step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch, start_time=start_time, \
                            log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])
           
    # save the final models
    utils.printout(conf.flog, 'Saving final checkpoint ...... ')
    utils.save_checkpoint(models=models, model_names=model_names, dirname=os.path.join(conf.exp_dir, 'ckpts'), \
            epoch=epoch, prepend_epoch=False, optimizers=optimizers, optimizer_names=optimizer_names)
    utils.printout(conf.flog, 'DONE')
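
# utils.worker_init_fn appears in the DataLoader calls above but is not shown;
# a typical sketch (an assumption, not the repo's verbatim helper) reseeds
# numpy in each DataLoader worker so augmentations differ across workers:
def worker_init_fn(worker_id):
    np.random.seed(np.random.get_state()[1][0] + worker_id)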

# Example 4

def train(conf):
    # load network model
    models = utils.get_model_module(conf.model_version)

    # check if training run already exists. If so, delete it.
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)) or \
       os.path.exists(os.path.join(conf.ckpt_path, conf.exp_name)):
        response = input(
            'A training run named "%s" already exists, overwrite? (y/n) ' %
            (conf.exp_name))
        if response != 'y':
            sys.exit()
    if os.path.exists(os.path.join(conf.log_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.log_path, conf.exp_name))
    if os.path.exists(os.path.join(conf.ckpt_path, conf.exp_name)):
        shutil.rmtree(os.path.join(conf.ckpt_path, conf.exp_name))

    # create directories for this run
    os.makedirs(os.path.join(conf.ckpt_path, conf.exp_name))
    os.makedirs(os.path.join(conf.log_path, conf.exp_name))

    # file log
    flog = open(os.path.join(conf.log_path, conf.exp_name, 'train.log'), 'w')

    # backup python files used for this training
    os.system('cp config.py data.py %s.py %s %s' %
              (conf.model_version, __file__,
               os.path.join(conf.log_path, conf.exp_name)))

    # set training device
    device = torch.device(conf.device)
    print(f'Using device: {conf.device}')
    flog.write(f'Using device: {conf.device}\n')

    # log the object category information
    print(f'Object Category: {conf.category}')
    flog.write(f'Object Category: {conf.category}\n')

    # control randomness
    if conf.seed < 0:
        conf.seed = random.randint(1, 10000)
    print("Random Seed: %d" % (conf.seed))
    flog.write(f'Random Seed: {conf.seed}\n')
    random.seed(conf.seed)
    np.random.seed(conf.seed)
    torch.manual_seed(conf.seed)

    # save config
    torch.save(conf, os.path.join(conf.ckpt_path, conf.exp_name, 'conf.pth'))

    # create models
    encoder = models.RecursiveEncoder(conf,
                                      variational=True,
                                      probabilistic=not conf.non_variational)
    decoder = models.RecursiveDecoder(conf)
    models = [encoder, decoder]
    model_names = ['encoder', 'decoder']

    # create optimizers
    encoder_opt = torch.optim.Adam(encoder.parameters(), lr=conf.lr)
    decoder_opt = torch.optim.Adam(decoder.parameters(), lr=conf.lr)
    optimizers = [encoder_opt, decoder_opt]
    optimizer_names = ['encoder', 'decoder']

    # learning rate scheduler
    encoder_scheduler = torch.optim.lr_scheduler.StepLR(
        encoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)
    decoder_scheduler = torch.optim.lr_scheduler.StepLR(
        decoder_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create training and validation datasets and data loaders
    data_features = ['object', 'diff']
    train_dataset = SynShapesDiffDataset(conf.train_dataset, data_features,
                                         conf.shapediff_topk)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=conf.batch_size,
        shuffle=True,
        collate_fn=utils.collate_feats)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR       BoxLoss   StructLoss  DNTypeLoss  DNBoxLoss  KLDivLoss    L1Loss   TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        # from tensorboardX import SummaryWriter
        from torch.utils.tensorboard import SummaryWriter
        tb_writer = SummaryWriter(
            log_dir=os.path.join(conf.log_path, conf.exp_name, 'train'))

    # send parameters to device
    for m in models:
        m.to(device)
    for o in optimizers:
        utils.optimizer_to_device(o, device)

    # start training
    print("Starting training ...... ")
    flog.write('Starting training ......\n')

    start_time = time.time()

    last_checkpoint_step = None
    last_train_console_log_step, last_valdt_console_log_step = None, None
    train_num_batch = len(train_dataloader)

    # train for every epoch
    for epoch in range(conf.epochs):
        if not conf.no_console_log:
            print(f'training run {conf.exp_name}')
            flog.write(f'training run {conf.exp_name}\n')
            print(header)
            flog.write(header + '\n')

        train_batches = enumerate(train_dataloader, 0)
        train_fraction_done = 0.0

        # train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # set models to training mode
            for m in models:
                m.train()

            # forward pass (including logging)
            total_loss = forward(batch=batch,
                                 data_features=data_features,
                                 encoder=encoder,
                                 decoder=decoder,
                                 device=device,
                                 conf=conf,
                                 is_valdt=False,
                                 step=train_step,
                                 epoch=epoch,
                                 batch_ind=train_batch_ind,
                                 num_batch=train_num_batch,
                                 start_time=start_time,
                                 log_console=log_console,
                                 log_tb=not conf.no_tb_log,
                                 tb_writer=tb_writer,
                                 lr=encoder_opt.param_groups[0]['lr'],
                                 flog=flog)

            # optimize one step
            encoder_opt.zero_grad()
            decoder_opt.zero_grad()
            total_loss.backward()
            # (the disabled gradient-norm dump that used to sit here is
            # distilled into the standalone helper sketched after this example)
            encoder_opt.step()
            decoder_opt.step()
            encoder_scheduler.step()
            decoder_scheduler.step()

            # save checkpoint
            with torch.no_grad():
                if last_checkpoint_step is None or \
                        train_step - last_checkpoint_step >= conf.checkpoint_interval:
                    print("Saving checkpoint ...... ", end='', flush=True)
                    flog.write("Saving checkpoint ...... ")
                    utils.save_checkpoint(models=models,
                                          model_names=model_names,
                                          dirname=os.path.join(
                                              conf.ckpt_path, conf.exp_name),
                                          epoch=epoch,
                                          prepend_epoch=True,
                                          optimizers=optimizers,
                                          optimizer_names=optimizer_names)
                    print("DONE")
                    flog.write("DONE\n")
                    last_checkpoint_step = train_step

    # save the final models
    print("Saving final checkpoint ...... ", end='', flush=True)
    flog.write("Saving final checkpoint ...... ")
    utils.save_checkpoint(models=models,
                          model_names=model_names,
                          dirname=os.path.join(conf.ckpt_path, conf.exp_name),
                          epoch=epoch,
                          prepend_epoch=False,
                          optimizers=optimizers,
                          optimizer_names=optimizer_names)
    print("DONE")
    flog.write("DONE\n")

    flog.close()
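
# Gradient-norm debug helper distilled from the disabled dump removed from the
# optimization step above (a sketch; call it between backward() and step()
# when diagnosing vanishing or exploding gradients):
def print_grad_norms(*modules):
    for module in modules:
        for name, param in module.named_parameters():
            if param.requires_grad and param.grad is not None:
                print(name, torch.norm(param.grad.data).item())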

# Example 5

def train(conf, train_shape_list, train_data_list, val_data_list,
          all_train_data_list):
    # create training and validation datasets and data loaders
    data_features = ['pcs', 'pc_pxids', 'pc_movables', 'gripper_img_target', 'gripper_direction_camera', 'gripper_forward_direction_camera', \
            'result', 'cur_dir', 'shape_id', 'trial_id', 'is_original']

    # load network model
    model_def = utils.get_model_module(conf.model_version)

    # create models
    network = model_def.Network(conf.feat_dim)
    utils.printout(conf.flog, '\n' + str(network) + '\n')

    # create optimizers
    network_opt = torch.optim.Adam(network.parameters(),
                                   lr=conf.lr,
                                   weight_decay=conf.weight_decay)

    # learning rate scheduler
    network_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        network_opt, step_size=conf.lr_decay_every, gamma=conf.lr_decay_by)

    # create logs
    if not conf.no_console_log:
        header = '     Time    Epoch     Dataset    Iteration    Progress(%)       LR    TotalLoss'
    if not conf.no_tb_log:
        # https://github.com/lanpa/tensorboard-pytorch
        from tensorboardX import SummaryWriter
        train_writer = SummaryWriter(os.path.join(conf.exp_dir, 'train'))
        val_writer = SummaryWriter(os.path.join(conf.exp_dir, 'val'))

    # send parameters to device
    network.to(conf.device)
    utils.optimizer_to_device(network_opt, conf.device)

    # load dataset
    train_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num, \
            abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres, img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)

    val_dataset = SAPIENVisionDataset([conf.primact_type], conf.category_types, data_features, conf.buffer_max_num, \
            abs_thres=conf.abs_thres, rel_thres=conf.rel_thres, dp_thres=conf.dp_thres, img_size=conf.img_size, no_true_false_equal=conf.no_true_false_equal)
    val_dataset.load_data(val_data_list)
    utils.printout(conf.flog, str(val_dataset))

    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=conf.batch_size, shuffle=False, pin_memory=True, \
            num_workers=0, drop_last=True, collate_fn=utils.collate_feats, worker_init_fn=utils.worker_init_fn)
    val_num_batch = len(val_dataloader)

    # create a data generator
    datagen = DataGen(conf.num_processes_for_datagen, conf.flog)

    # sample succ
    if conf.sample_succ:
        sample_succ_list = []
        sample_succ_dirs = []

    # start training
    start_time = time.time()

    last_train_console_log_step, last_val_console_log_step = None, None

    # if resuming, pick up from the latest saved epoch
    start_epoch = 0
    if conf.resume:
        # figure out the latest epoch to resume; os.listdir order is
        # arbitrary, so take the max rather than the last entry seen
        for item in os.listdir(os.path.join(conf.exp_dir, 'ckpts')):
            if item.endswith('-train_dataset.pth'):
                start_epoch = max(start_epoch, int(item.split('-')[0]))

        # load states for network, optimizer, lr_scheduler, sample_succ_list
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-network.pth' % start_epoch))
        network.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-optimizer.pth' % start_epoch))
        network_opt.load_state_dict(data_to_restore)
        data_to_restore = torch.load(
            os.path.join(conf.exp_dir, 'ckpts',
                         '%d-lr_scheduler.pth' % start_epoch))
        network_lr_scheduler.load_state_dict(data_to_restore)

        # rmdir and make a new dir for the current sample-succ directory
        old_sample_succ_dir = os.path.join(
            conf.data_dir, 'epoch-%04d_sample-succ' % (start_epoch - 1))
        utils.force_mkdir(old_sample_succ_dir)

    # train for every epoch
    for epoch in range(start_epoch, conf.epochs):
        ### collect data for the current epoch
        if epoch > start_epoch:
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Waiting epoch-{epoch} data ]'
            )
            train_data_list = datagen.join_all()
            utils.printout(
                conf.flog,
                f'  [{strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Gathered epoch-{epoch} data ]'
            )
            cur_data_folders = []
            for item in train_data_list:
                item = '/'.join(item.split('/')[:-1])
                if item not in cur_data_folders:
                    cur_data_folders.append(item)
            for cur_data_folder in cur_data_folders:
                with open(os.path.join(cur_data_folder, 'data_tuple_list.txt'),
                          'w') as fout:
                    for item in train_data_list:
                        if cur_data_folder == '/'.join(item.split('/')[:-1]):
                            fout.write(item.split('/')[-1] + '\n')

            # load offline-generated sample-random data
            for item in all_train_data_list:
                valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                    epoch - 1)
                valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch
                if valid_id_l <= int(item.split('_')[-1]) < valid_id_r:
                    train_data_list.append(item)

        ### start generating data for the next epoch
        # sample succ
        if conf.sample_succ:
            if conf.resume and epoch == start_epoch:
                sample_succ_list = torch.load(
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % start_epoch))
            else:
                torch.save(
                    sample_succ_list,
                    os.path.join(conf.exp_dir, 'ckpts',
                                 '%d-sample_succ_list.pth' % epoch))
            for item in sample_succ_list:
                datagen.add_one_recollect_job(item[0], item[1], item[2],
                                              item[3], item[4], item[5],
                                              item[6])
            sample_succ_list = []
            sample_succ_dirs = []
            cur_sample_succ_dir = os.path.join(
                conf.data_dir, 'epoch-%04d_sample-succ' % epoch)
            utils.force_mkdir(cur_sample_succ_dir)

        # start all jobs
        datagen.start_all()
        utils.printout(
            conf.flog,
            f'  [ {strftime("%H:%M:%S", time.gmtime(time.time()-start_time)):>9s} Started generating epoch-{epoch+1} data ]'
        )

        ### load data for the current epoch
        if conf.resume and epoch == start_epoch:
            train_dataset = torch.load(
                os.path.join(conf.exp_dir, 'ckpts',
                             '%d-train_dataset.pth' % start_epoch))
        else:
            train_dataset.load_data(train_data_list)
        utils.printout(conf.flog, str(train_dataset))
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=conf.batch_size, shuffle=True, pin_memory=True, \
                num_workers=0, drop_last=True, collate_fn=utils.collate_feats, worker_init_fn=utils.worker_init_fn)
        train_num_batch = len(train_dataloader)

        ### print log
        if not conf.no_console_log:
            utils.printout(conf.flog, f'training run {conf.exp_name}')
            utils.printout(conf.flog, header)

        train_batches = enumerate(train_dataloader, 0)
        val_batches = enumerate(val_dataloader, 0)

        train_fraction_done = 0.0
        val_fraction_done = 0.0
        val_batch_ind = -1

        ### train for every batch
        for train_batch_ind, batch in train_batches:
            train_fraction_done = (train_batch_ind + 1) / train_num_batch
            train_step = epoch * train_num_batch + train_batch_ind

            log_console = not conf.no_console_log and (last_train_console_log_step is None or \
                    train_step - last_train_console_log_step >= conf.console_log_interval)
            if log_console:
                last_train_console_log_step = train_step

            # save checkpoint
            if train_batch_ind == 0:
                with torch.no_grad():
                    utils.printout(conf.flog, 'Saving checkpoint ...... ')
                    torch.save(
                        network.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-network.pth' % epoch))
                    torch.save(
                        network_opt.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-optimizer.pth' % epoch))
                    torch.save(
                        network_lr_scheduler.state_dict(),
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-lr_scheduler.pth' % epoch))
                    torch.save(
                        train_dataset,
                        os.path.join(conf.exp_dir, 'ckpts',
                                     '%d-train_dataset.pth' % epoch))
                    utils.printout(conf.flog, 'DONE')

            # set models to training mode
            network.train()

            # forward pass (including logging)
            total_loss, whole_feats, whole_pcs, whole_pxids, whole_movables = forward(batch=batch, data_features=data_features, network=network, conf=conf, is_val=False, \
                    step=train_step, epoch=epoch, batch_ind=train_batch_ind, num_batch=train_num_batch, start_time=start_time, \
                    log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=train_writer, lr=network_opt.param_groups[0]['lr'])

            # optimize one step
            network_opt.zero_grad()
            total_loss.backward()
            network_opt.step()
            network_lr_scheduler.step()

            # sample succ
            if conf.sample_succ:
                network.eval()

                with torch.no_grad():
                    # sample a random EE orientation
                    random_up = torch.randn(conf.batch_size,
                                            3).float().to(conf.device)
                    random_forward = torch.randn(conf.batch_size,
                                                 3).float().to(conf.device)
                    random_left = torch.cross(random_up, random_forward, dim=1)
                    random_forward = torch.cross(random_left, random_up, dim=1)
                    random_dirs1 = F.normalize(random_up, dim=1).float()
                    random_dirs2 = F.normalize(random_forward, dim=1).float()

                    # test over the entire image
                    whole_pc_scores1 = network.inference_whole_pc(
                        whole_feats, random_dirs1, random_dirs2)  # B x N
                    whole_pc_scores2 = network.inference_whole_pc(
                        whole_feats, -random_dirs1, random_dirs2)  # B x N

                    # add to the sample_succ_list if wanted
                    ss_cur_dir = batch[data_features.index('cur_dir')]
                    ss_shape_id = batch[data_features.index('shape_id')]
                    ss_trial_id = batch[data_features.index('trial_id')]
                    ss_is_original = batch[data_features.index('is_original')]
                    for i in range(conf.batch_size):
                        valid_id_l = conf.num_interaction_data_offline + conf.num_interaction_data * (
                            epoch - 1)
                        valid_id_r = conf.num_interaction_data_offline + conf.num_interaction_data * epoch

                        if ('sample-succ' not in ss_cur_dir[i]) and (ss_is_original[i]) and (ss_cur_dir[i] not in sample_succ_dirs) \
                                and (valid_id_l <= int(ss_trial_id[i]) < valid_id_r):
                            sample_succ_dirs.append(ss_cur_dir[i])

                            # choose one from the two options
                            gt_movable = whole_movables[i].cpu().numpy()

                            whole_pc_score1 = whole_pc_scores1[i].cpu().numpy() * gt_movable
                            whole_pc_score1[whole_pc_score1 < 0.5] = 0
                            whole_pc_score_sum1 = np.sum(whole_pc_score1) + 1e-12

                            whole_pc_score2 = whole_pc_scores2[i].cpu().numpy() * gt_movable
                            whole_pc_score2[whole_pc_score2 < 0.5] = 0
                            whole_pc_score_sum2 = np.sum(whole_pc_score2) + 1e-12

                            choose1or2_ratio = whole_pc_score_sum1 / (
                                whole_pc_score_sum1 + whole_pc_score_sum2)
                            random_dir1 = random_dirs1[i].cpu().numpy()
                            random_dir2 = random_dirs2[i].cpu().numpy()
                            if np.random.random() < choose1or2_ratio:
                                whole_pc_score = whole_pc_score1
                            else:
                                whole_pc_score = whole_pc_score2
                                random_dir1 = -random_dir1

                            # sample <X, Y> on each img
                            pp = whole_pc_score + 1e-12
                            ptid = np.random.choice(len(whole_pc_score),
                                                    1,
                                                    p=pp / pp.sum())
                            X = whole_pxids[i, ptid, 0].item()
                            Y = whole_pxids[i, ptid, 1].item()

                            # add job to the queue
                            str_cur_dir1 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir1])
                            str_cur_dir2 = ',' + ','.join(
                                ['%f' % elem for elem in random_dir2])
                            sample_succ_list.append((conf.offline_data_dir, str_cur_dir1, str_cur_dir2, \
                                    ss_cur_dir[i].split('/')[-1], cur_sample_succ_dir, X, Y))

            # validate one batch
            while val_fraction_done <= train_fraction_done and val_batch_ind + 1 < val_num_batch:
                val_batch_ind, val_batch = next(val_batches)

                val_fraction_done = (val_batch_ind + 1) / val_num_batch
                val_step = (epoch + val_fraction_done) * train_num_batch - 1

                log_console = not conf.no_console_log and (last_val_console_log_step is None or \
                        val_step - last_val_console_log_step >= conf.console_log_interval)
                if log_console:
                    last_val_console_log_step = val_step

                # set models to evaluation mode
                network.eval()

                with torch.no_grad():
                    # forward pass (including logging)
                    __ = forward(batch=val_batch, data_features=data_features, network=network, conf=conf, is_val=True, \
                            step=val_step, epoch=epoch, batch_ind=val_batch_ind, num_batch=val_num_batch, start_time=start_time, \
                            log_console=log_console, log_tb=not conf.no_tb_log, tb_writer=val_writer, lr=network_opt.param_groups[0]['lr'])