Example #1
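A multi-GPU TensorFlow 1.x training loop: it builds the input pipeline, wraps the network with a project-specific DataParallel helper, and drives training through tf.train.Supervisor with manual checkpointing. train_inputs, net_select, DataParallel, DataParallelSaverBuilder, get_local_init_ops, and the FLAGS definitions come from the surrounding project; only the standard-library and TensorFlow imports are added below for completeness.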
import os
import time
from datetime import datetime

import tensorflow as tf


def train():
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Global step counter
        global_step = tf.train.get_or_create_global_step()
        # Data
        images, labels, num_classes, num_examples = train_inputs(
            FLAGS.data_list_path,
            FLAGS.batch_size,
            FLAGS.is_color,
            input_height=FLAGS.input_height,
            input_width=FLAGS.input_width,
            augment=False,
            num_preprocess_threads=FLAGS.num_threads_per_gpu * FLAGS.num_gpus)
        # Ceiling division so the final partial batch is counted exactly once
        batches_per_epoch = (num_examples + FLAGS.batch_size - 1) // FLAGS.batch_size

        # Network
        network = net_select(FLAGS.net_name, data_format=FLAGS.data_format)

        # DataParallel
        model = DataParallel(network,
                             init_lr=FLAGS.init_lr,
                             decay_epoch=FLAGS.lr_decay_epoch,
                             decay_rate=FLAGS.lr_decay_rate,
                             batches_per_epoch=batches_per_epoch,
                             num_gpus=FLAGS.num_gpus)

        # Inference: returns the train op, the learning-rate tensor, and the
        # per-loss tensors together with their display names
        train_ops, lr, losses, losses_name = model(images=images,
                                                   labels=labels,
                                                   num_classes=num_classes)

        # Saver: checkpoint all global variables, keeping the 20 newest files
        saver = tf.train.Saver(tf.global_variables(),
                               max_to_keep=20,
                               builder=DataParallelSaverBuilder())

        # Supervisor handles variable initialization and queue runners;
        # save_model_secs=0 disables its automatic checkpointing, so
        # checkpoints are written manually below
        sv = tf.train.Supervisor(logdir=os.path.join(
            FLAGS.train_dir, FLAGS.net_name + '_' + FLAGS.model_name),
                                 local_init_op=get_local_init_ops(),
                                 saver=saver,
                                 global_step=global_step,
                                 save_model_secs=0)

        # Session config: fall back to CPU for ops without a GPU kernel and
        # grow GPU memory on demand instead of reserving it all up front
        config = tf.ConfigProto(allow_soft_placement=True,
                                log_device_placement=False,
                                gpu_options=tf.GPUOptions(allow_growth=True))

        # Log format: one timestamped header line, one line per loss, and one
        # line for batch time and throughput
        format_str = '[%s] Epoch/Step %d/%d, lr = %g\n'
        for loss_id, loss_name in enumerate(losses_name):
            format_str += '[%s]    Loss #' + str(
                loss_id) + ': ' + loss_name + ' = %.6f\n'
        format_str += '[%s]    batch_time = %.1fms/batch, throughput = %.1fimages/s'

        # Training session
        with sv.managed_session(config=config) as sess:
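            # Resume from the newest checkpoint if one exists; otherwise keep
            # the fresh initialization performed by the Supervisor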
            ckpt = tf.train.get_checkpoint_state(
                os.path.join(FLAGS.model_dir,
                             FLAGS.net_name + '_' + FLAGS.model_name))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('Model restored from %s' % os.path.join(
                    FLAGS.model_dir, FLAGS.net_name + '_' + FLAGS.model_name))
            else:
                print('Network parameters initialized from scratch.')

            print('%s training start...' %
                  (FLAGS.net_name + '_' + FLAGS.model_name))
            step = 0
            epoch = 1
            while epoch <= FLAGS.max_epoches:
                step = sess.run(global_step)
                epoch = step // batches_per_epoch + 1

                start_time = time.time()
                output = sess.run([train_ops, lr] + losses)
                learning_rate = output[1]
                losses_value = output[2:]
                duration = time.time() - start_time

                if step % FLAGS.display_interval == 0:
                    examples_per_sec = FLAGS.batch_size / duration
                    ms_per_batch = duration * 1000

                    # Format tuple
                    format_list = [datetime.now(), epoch, step, learning_rate]
                    for loss_value in losses_value:
                        format_list.extend([datetime.now(), loss_value])
                    format_list.extend(
                        [datetime.now(), ms_per_batch, examples_per_sec])
                    print(format_str % tuple(format_list))

                # Save periodically and at the final step, independent of the
                # display interval
                if ((step > 0 and step % FLAGS.save_interval == 0)
                        or step == FLAGS.max_epoches * batches_per_epoch):
                    train_path = os.path.join(
                        FLAGS.model_dir,
                        FLAGS.net_name + '_' + FLAGS.model_name,
                        FLAGS.net_name + '_' + FLAGS.model_name + '.ckpt')
                    saver.save(sess, train_path, global_step=step)
                    print('[%s]: Model has been saved in Iteration %d' %
                          (datetime.now(), step))
Example #2
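An evaluation entry point for a PyTorch graph-matching model: the network class is loaded dynamically from the module named by cfg.MODULE, the test split is wrapped in a dataloader, the model is spread across all visible GPUs with DataParallel, and eval_model runs while stdout is mirrored to a log file. This is a fragment of a larger script; cfg, GMDataset, get_dataloader, DupStdoutFileManager, print_easydict, eval_model, and the imports are defined upstream.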
    mod = importlib.import_module(cfg.MODULE)
    Net = mod.Net

    # Fix the RNG seed so evaluation is reproducible across runs
    torch.manual_seed(cfg.RANDOM_SEED)

    image_dataset = GMDataset(cfg.DATASET_FULL_NAME,
                              sets='test',
                              length=cfg.EVAL.SAMPLES,
                              obj_resize=cfg.PAIR.RESCALE)
    dataloader = get_dataloader(image_dataset)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net()
    # Honor the CPU fallback computed above instead of calling .cuda()
    # unconditionally, which fails on a CPU-only machine
    model = model.to(device)
    if torch.cuda.is_available():
        model = DataParallel(model, device_ids=range(torch.cuda.device_count()))

    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)
    now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    with DupStdoutFileManager(
            str(Path(cfg.OUTPUT_PATH) /
                ('eval_log_' + now_time + '.log'))) as _:
        print_easydict(cfg)
        classes = dataloader.dataset.classes
        pcks = eval_model(
            model,
            dataloader,
            eval_epoch=cfg.EVAL.EPOCH if cfg.EVAL.EPOCH != 0 else None,
            verbose=True)
Example #3
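The corresponding training entry point from the same project: a criterion is chosen from cfg.TRAIN.LOSS_FUNC, an SGD optimizer with Nesterov momentum is built, the model is wrapped in DataParallel on the configured GPUs, a TensorBoard SummaryWriter is created, and train_eval_model performs the actual optimization. As in Example #2, this is a fragment; image_dataset, Net, cfg, and the helper functions are defined upstream.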
    dataloader = {x: get_dataloader(image_dataset[x], fix_seed=(x == 'test'))
                  for x in ('train', 'test')}

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = Net()
    model = model.to(device)  # use the computed device rather than bare .cuda()

    # Pick the criterion named in the config: a robust regression loss
    # ('offset') or a permutation cross-entropy loss ('perm')
    if cfg.TRAIN.LOSS_FUNC == 'offset':
        criterion = RobustLoss(norm=cfg.TRAIN.RLOSS_NORM)
    elif cfg.TRAIN.LOSS_FUNC == 'perm':
        criterion = CrossEntropyLoss()
    else:
        raise ValueError('Unknown loss function {}'.format(cfg.TRAIN.LOSS_FUNC))

    optimizer = optim.SGD(model.parameters(),
                          lr=cfg.TRAIN.LR,
                          momentum=cfg.TRAIN.MOMENTUM,
                          nesterov=True)

    # Wrap for multi-GPU execution; the optimizer built above already holds
    # references to the shared underlying parameters
    model = DataParallel(model, device_ids=cfg.GPUS)

    Path(cfg.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

    now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    tfboardwriter = SummaryWriter(
        logdir=str(Path(cfg.OUTPUT_PATH) / 'tensorboard' /
                   'training_{}'.format(now_time)))

    with DupStdoutFileManager(
            str(Path(cfg.OUTPUT_PATH) /
                ('train_log_' + now_time + '.log'))) as _:
        print_easydict(cfg)
        model = train_eval_model(model, criterion, optimizer, dataloader, tfboardwriter,
                                 num_epochs=cfg.TRAIN.NUM_EPOCHS,
                                 resume=cfg.TRAIN.START_EPOCH != 0,
                                 start_epoch=cfg.TRAIN.START_EPOCH)