Code Example #1
# Assumed imports for this snippet (CNTK 2.x exposes these at the package top level).
from cntk import (Trainer, UnitType, learning_rate_schedule, momentum_schedule,
                  nesterov, block_momentum_distributed_learner,
                  data_parallel_distributed_learner)

def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    lr_per_mb = [1.0]*30 + [0.1]*30 + [0.01]*20 + [0.001]
    l2_reg_weight = 0.0001

    # adjust LR with minibatch size
    if minibatch_size != 256:
        for i in range(0, len(lr_per_mb)):
            lr_per_mb[i] *= minibatch_size / 256

    # Set learning parameters
    lr_schedule = learning_rate_schedule(lr_per_mb, epoch_size=epoch_size, unit=UnitType.minibatch)
    mm_schedule = momentum_schedule(0.9)

    local_learner = nesterov(network['output'].parameters, lr_schedule, mm_schedule,
                             l2_regularization_weight=l2_reg_weight)

    # learner object
    if block_size is not None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    if block_size is not None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner, num_quantization_bits=num_quantization_bits, distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['errs']), learner, progress_printer)
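
A minimal usage sketch (assumptions: a `network` dict exposing 'output', 'ce' and 'errs' nodes and a configured minibatch source exist; the argument values below are illustrative, not from the original). After all workers finish training, distributed learners require a single `Communicator.finalize()` call:

import cntk as C

trainer = create_trainer(network, minibatch_size=256, epoch_size=50000,
                         num_quantization_bits=32, block_size=None, warm_up=0,
                         progress_printer=C.logging.ProgressPrinter(freq=100))
# ... feed data via trainer.train_minibatch(...) as in the later examples ...
C.Communicator.finalize()  # required shutdown step for distributed training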
Code Example #2
    def create_trainer(self):
        try:
            p = self.output.parameters
            # Three of four parameters are learned by block_momentum_distributed_learner.
            bmd_learner = cntk.block_momentum_distributed_learner(
                cntk.momentum_sgd(
                    [p[0], p[1], p[2]],
                    cntk.learning_parameter_schedule(0.0001),
                    cntk.momentum_as_time_constant_schedule(1000)),
                block_size=1000,
                block_learning_rate=0.01,
                block_momentum_as_time_constant=1000)

            # New API to mark which learner is used for metric aggregation.
            bmd_learner.set_as_metric_aggregator()

            # The last parameter is learned by the data_parallel_distributed_learner.
            momentum_schedule = cntk.momentum_schedule_per_sample(
                0.9990913221888589)
            lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
            dpd_learner = cntk.data_parallel_distributed_learner(
                cntk.momentum_sgd([p[3]], lr_per_sample, momentum_schedule,
                                  True))

            comm_rank = cntk.distributed.Communicator.rank()
            self.trainer = cntk.Trainer(
                self.output, (self.ce, self.err), [bmd_learner, dpd_learner], [
                    cntk.logging.ProgressPrinter(
                        freq=progress_freq, tag="Training", rank=comm_rank)
                ])
        except RuntimeError:
            self.trainer = None
        return
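
When a single model is trained with several distributed learners, as above, exactly one of them should be marked with `set_as_metric_aggregator()` so that training metrics are aggregated across workers by that learner; here the block-momentum learner takes that role while the data-parallel learner handles only the last parameter.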
Code Example #3
 def create_distributed_learner(self, mode, config):
     local_learner = C.sgd(self.z.parameters,
                           C.learning_parameter_schedule_per_sample(0.01))
     try:
         if mode == 'data_parallel':
             if config is None:
                 config = DataParallelConfig(num_quantization_bits=32,
                                             distributed_after=0)
             learner = C.data_parallel_distributed_learner(
                 local_learner,
                 num_quantization_bits=config.num_quantization_bits,
                 distributed_after=config.distributed_after)
         elif mode == 'block_momentum':
             if config is None:
                 # the default config to match data parallel SGD
                 config = BlockMomentumConfig(
                     block_momentum_as_time_constant=0,
                     block_learning_rate=1,
                     block_size=NUM_WORKERS,
                     distributed_after=0)
             learner = C.block_momentum_distributed_learner(
                 local_learner,
                 block_momentum_as_time_constant=config.block_momentum_as_time_constant,
                 block_learning_rate=config.block_learning_rate,
                 block_size=config.block_size,
                 distributed_after=config.distributed_after)
         else:
             learner = local_learner
     except RuntimeError:
         learner = None
     return learner
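
The snippet references `DataParallelConfig`, `BlockMomentumConfig` and `NUM_WORKERS`, which are defined elsewhere in the original test; a hypothetical minimal sketch of those containers:

from collections import namedtuple

# Hypothetical stand-ins for the config containers used above.
DataParallelConfig = namedtuple(
    'DataParallelConfig', ['num_quantization_bits', 'distributed_after'])
BlockMomentumConfig = namedtuple(
    'BlockMomentumConfig', ['block_momentum_as_time_constant',
                            'block_learning_rate', 'block_size',
                            'distributed_after'])
NUM_WORKERS = 2  # placeholder value; set to the actual worker count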
Code Example #4
# Assumed imports for this snippet (CNTK 2.x top-level API plus NumPy).
import numpy as np
from cntk import (Trainer, UnitType, learning_rate_schedule,
                  momentum_as_time_constant_schedule, momentum_sgd,
                  block_momentum_distributed_learner,
                  data_parallel_distributed_learner)

def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    if network['name'] == 'resnet20': 
        lr_per_mb = [1.0]*80+[0.1]*40+[0.01]
    elif network['name'] == 'resnet110': 
        lr_per_mb = [0.1]*1+[1.0]*80+[0.1]*40+[0.01]
    else: 
        raise RuntimeError("Unknown model name!")

    momentum_time_constant = -minibatch_size/np.log(0.9)
    l2_reg_weight = 0.0001

    # Set learning parameters
    lr_per_sample = [lr/minibatch_size for lr in lr_per_mb]
    lr_schedule = learning_rate_schedule(lr_per_sample, epoch_size=epoch_size, unit=UnitType.sample)
    mm_schedule = momentum_as_time_constant_schedule(momentum_time_constant)
    
    # learner object
    if block_size is not None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    local_learner = momentum_sgd(network['output'].parameters, lr_schedule, mm_schedule,
                                 l2_regularization_weight = l2_reg_weight)

    if block_size is not None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner, num_quantization_bits=num_quantization_bits, distributed_after=warm_up)
    
    return Trainer(network['output'], (network['ce'], network['pe']), learner, progress_printer)
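
A quick sanity check of the momentum conversion used above (illustrative, not from the original): a per-minibatch momentum of 0.9 corresponds to the time constant -minibatch_size/ln(0.9), since exp(-minibatch_size/time_constant) recovers 0.9.

import numpy as np

mb_size = 128  # hypothetical minibatch size
time_constant = -mb_size / np.log(0.9)
assert abs(np.exp(-mb_size / time_constant) - 0.9) < 1e-12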
Code Example #6
def create_trainer(network, minibatch_size, epoch_size, num_quantization_bits, block_size, warm_up, progress_printer):
    lr_per_mb = [0.1] # [1.0]*30 + [0.1]*30 + [0.01]*20 + [0.001]
    l2_reg_weight = 0.0001

    # adjust LR with minibatch size
    #if minibatch_size != 256:
    #    for i in range(0, len(lr_per_mb)):
    #        lr_per_mb[i] *= minibatch_size / 256

    # Set learning parameters
    lr_schedule = learning_rate_schedule(lr_per_mb, epoch_size=epoch_size, unit=UnitType.minibatch)
    mm_schedule = momentum_schedule(0.9)

    local_learner = nesterov(network['output'].parameters, lr_schedule, mm_schedule,
                             l2_regularization_weight=l2_reg_weight)

    # learner object
    if block_size is not None and num_quantization_bits != 32:
        raise RuntimeError("Block momentum cannot be used with quantization, please remove quantized_bits option.")

    if block_size is not None:
        learner = block_momentum_distributed_learner(local_learner, block_size=block_size)
    else:
        learner = data_parallel_distributed_learner(local_learner, num_quantization_bits=num_quantization_bits, distributed_after=warm_up)

    return Trainer(network['output'], (network['ce'], network['errs']), learner, progress_printer)
Code Example #7
 def __init__(self):
     self.input_dim = 40000
     self.embed_dim = 100
     self.batch_size = 20
     i = C.input_variable((self.input_dim, ), is_sparse=True)
     self.p = C.parameter(shape=(self.input_dim, self.embed_dim), init=1)
     o = C.times(i, self.p)
     z = C.reduce_sum(o)
     learner = C.data_parallel_distributed_learner(
         C.sgd(
             z.parameters,
             C.learning_rate_schedule(0.01,
                                      unit=C.learners.UnitType.sample)))
     self.trainer = C.Trainer(z, (z, None), learner, [])
Code Example #8
    def create_trainer(self):
        try:
            lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
            p = self.output.parameters
            # Three of four parameters are learned by first data_parallel_distributed_learner.
            learner1 = cntk.data_parallel_distributed_learner(
                cntk.sgd([p[0], p[1], p[2]], lr_per_sample))

            # New API to mark which learner is used for metric aggregation.
            learner1.set_as_metric_aggregator()

            # The last parameter is learned by another data_parallel_distributed_learner.
            learner2 = cntk.data_parallel_distributed_learner(
                cntk.sgd([p[3]], lr_per_sample))

            comm_rank = cntk.distributed.Communicator.rank()
            self.trainer = cntk.Trainer(
                self.output, (self.ce, self.err), [learner1, learner2], [
                    cntk.logging.ProgressPrinter(
                        freq=progress_freq, tag="Training", rank=comm_rank)
                ])
        except RuntimeError:
            self.trainer = None
        return
Code Example #9
    def create_trainer(self):
        try:
            lr_per_sample = cntk.learning_parameter_schedule_per_sample(0.007)
            learner = cntk.data_parallel_distributed_learner(
                cntk.sgd(self.output.parameters, lr_per_sample))

            comm_rank = cntk.distributed.Communicator.rank()
            self.trainer = cntk.Trainer(
                self.output, (self.ce, self.err), [learner], [
                    cntk.logging.ProgressPrinter(
                        freq=progress_freq, tag="Training", rank=comm_rank)
                ])
        except RuntimeError:
            self.trainer = None
        return
Code Example #10
 def create_distributed_learner(self, mode, config):
     local_learner = C.sgd(self.z.parameters, C.learning_parameter_schedule_per_sample(0.01))
     try:
         if mode == 'data_parallel':
             if config is None:
                 config = DataParallelConfig(num_quantization_bits=32, distributed_after=0)
             learner = C.data_parallel_distributed_learner(local_learner, num_quantization_bits=config.num_quantization_bits, distributed_after=config.distributed_after)
         elif mode == 'block_momentum':
             if config is None:
                 # the default config to match data parallel SGD
                 config = BlockMomentumConfig(block_momentum_as_time_constant=0, block_learning_rate=1, block_size=NUM_WORKERS, distributed_after=0)
             learner = C.block_momentum_distributed_learner(local_learner, block_momentum_as_time_constant=config.block_momentum_as_time_constant, block_learning_rate=config.block_learning_rate, block_size=config.block_size, distributed_after=config.distributed_after)
         else:
             learner = local_learner
     except RuntimeError:
         learner = None
     return learner
Code Example #11
File: ct.py Project: seasky100/DL_benchmarks
    def run(self, iterator, mode='train'):
        report = dict()
        input_var = C.ops.input_variable(np.prod(iterator.image_shape),
                                         np.float32)
        label_var = C.ops.input_variable(iterator.batch_size, np.float32)
        model = self.model(input_var,)
        ce = C.losses.cross_entropy_with_softmax(model, label_var)
        pe = C.metrics.classification_error(model, label_var)
        z = cnn(input_var)
        learner = C.learners.momentum_sgd(z.parameters, self.lr_schedule, self.m_schedule)
        if self.is_parallel:
            distributed_learner = \
                    C.data_parallel_distributed_learner(learner=learner,
                                                    distributed_after=0)

        progress_printer = \
                    C.logging.ProgressPrinter(tag='Training',
                                              num_epochs=iterator.niteration)            
        if self.is_parallel:
            trainer = C.Trainer(z, (ce, pe), distributed_learner,
                                progress_printer)
        else:
            trainer = C.Trainer(z, (ce, pe), learner, progress_printer)        

        for idx, (x, t) in enumerate(iterator):
            total_s = time.perf_counter()
            trainer.train_minibatch({input_var : x, label_var : t})
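            # NOTE: train_minibatch() above already runs the full forward and
            # backward pass, so the forward/backward timestamps below bracket
            # no work and report ~0 elapsed time; only 'total' is meaningful.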
            forward_s = time.perf_counter()
            forward_e = time.perf_counter()
            backward_s = time.perf_counter()
            backward_e = time.perf_counter()            
            total_e = time.perf_counter()
            
            report[idx] = dict(
                forward=forward_e - forward_s,
                backward=backward_e - backward_s,
                total=total_e - total_s
            )
        return report
Code Example #12
def train(data_path,
          model_path,
          log_file,
          config_file,
          restore=False,
          profiling=False,
          gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    minibatch_size = training_config['minibatch_size']
    max_epochs = training_config['max_epochs']
    epoch_size = training_config['epoch_size']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'],
                                       minibatch_size=None,
                                       epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0,
                           shape=p.shape,
                           dtype=p.dtype,
                           name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p,
                                             0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner,
                                                      num_quantization_bits=1)

    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {'best_val_err': 100, 'best_since': 0, 'val_since': 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model,
            polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1

        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model,
                polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False

        if profiling:
            C.debugging.enable_profiler()

        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file,
                                                 polymath)

        for epoch in range(max_epochs):
            num_seq = 0
            with tqdm(total=epoch_size, ncols=32,
                      smoothing=0.1) as progress_bar:
                while True:
                    if trainer.total_number_of_samples_seen >= training_config[
                            'distributed_after']:
                        data = mb_source.next_minibatch(
                            minibatch_size * C.Communicator.num_workers(),
                            input_map=input_map,
                            num_data_partitions=C.Communicator.num_workers(),
                            partition_index=C.Communicator.rank())
                    else:
                        data = mb_source.next_minibatch(minibatch_size,
                                                        input_map=input_map)

                    trainer.train_minibatch(data)
                    num_seq += trainer.previous_minibatch_sample_count
                    dummy.eval()
                    if num_seq >= epoch_size:
                        break
                    else:
                        progress_bar.update(
                            trainer.previous_minibatch_sample_count)
                if not post_epoch_work(epoch_stat):
                    break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
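
Two details in this example: `dummy.eval()` inside the training loop executes the `C.assign` ops built earlier, advancing the exponential moving average `ema_p = 0.999*ema_p + 0.001*p` that `post_epoch_work` temporarily swaps into the model for validation. Also, `num_quantization_bits=1` requests 1-bit SGD gradient quantization, which is only available in CNTK builds that include the 1-bit SGD components; pass 32 to disable quantization.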
Code Example #13
def train(data_path,
          model_path,
          log_file,
          config_file,
          restore=False,
          profiling=False,
          gen_heartbeat=False):
    training_config = importlib.import_module(config_file).training_config
    # config for using multi GPUs
    if training_config['multi_gpu']:
        gpu_pad = training_config['gpu_pad']
        gpu_cnt = training_config['gpu_cnt']
        my_rank = C.Communicator.rank()
        my_gpu_id = (my_rank + gpu_pad) % gpu_cnt
        print("rank = " + str(my_rank) + ", using gpu " + str(my_gpu_id) +
              " of " + str(gpu_cnt))
        C.try_set_default_device(C.gpu(my_gpu_id))
    else:
        C.try_set_default_device(C.gpu(0))
    # outputs while training
    normal_log = os.path.join(data_path, training_config['logdir'], log_file)
    # tensorboard files' dir
    tensorboard_logdir = os.path.join(data_path, training_config['logdir'],
                                      log_file)

    polymath = PolyMath(config_file)
    z, loss = polymath.model()

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=normal_log,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]
    # add tensorboard writer for visualize
    tensorboard_writer = C.logging.TensorBoardProgressWriter(
        freq=10,
        log_dir=tensorboard_logdir,
        rank=C.Communicator.rank(),
        model=z)
    progress_writers.append(tensorboard_writer)

    lr = C.learning_parameter_schedule(training_config['lr'],
                                       minibatch_size=None,
                                       epoch_size=None)

    ema = {}
    dummies_info = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0,
                           shape=p.shape,
                           dtype=p.dtype,
                           name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, p)))
        dummies_info[dummies[-1].output] = (p.name, p.shape)
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    trainer = C.Trainer(z, (loss, None), learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {
        'best_val_err': 100,
        'best_since': 0,
        'val_since': 0,
        'record_num': 0
    }

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        #after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(
            os.path.join(data_path, training_config['val_data']), model,
            polymath, config_file)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1

        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(
                os.path.join(data_path, training_config['val_data']), model,
                polymath, config_file)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                os.system("ls -la >> log.log")
                os.system("ls -la ./Models >> log.log")
                save_flag = True
                fail_cnt = 0
                while save_flag:
                    if fail_cnt > 100:
                        print("ERROR: failed to save models")
                        break
                    try:
                        trainer.save_checkpoint(model_file)
                        epoch_stat['record_num'] += 1
                        record_file = os.path.join(
                            model_path,
                            str(epoch_stat['record_num']) + '-' + model_name)
                        trainer.save_checkpoint(record_file)
                        save_flag = False
                    except Exception:
                        fail_cnt = fail_cnt + 1
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False

        if profiling:
            C.debugging.enable_profiler()

        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file,
                                                 polymath)

        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']

        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config[
                        'distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size,
                                                    input_map=input_map)

                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                # print_para_info(dummy, dummies_info)
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
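
Note that, unlike Example #12, the `C.assign(ema_p, p)` here copies the raw parameter values with no decay, so the `ema` constants are snapshots rather than a moving average; also, in the .ctf branch the `dummy.eval()` call is commented out, leaving those snapshots at their initial zeros when validation swaps them in.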
Code Example #14
File: train_pm.py Project: Robinysh/asc18-cntk
def train(i2w,
          data_path,
          model_path,
          log_file,
          config_file,
          restore=True,
          profiling=False,
          gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config
    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [
        C.logging.ProgressPrinter(num_epochs=max_epochs,
                                  freq=log_freq,
                                  tag='Training',
                                  log_to_file=log_file,
                                  metric_is_pct=False,
                                  rank=C.Communicator.rank(),
                                  gen_heartbeat=gen_heartbeat)
    ]

    lr = C.learning_parameter_schedule(training_config['lr'],
                                       minibatch_size=None,
                                       epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0,
                           shape=p.shape,
                           dtype=p.dtype,
                           name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p,
                                             0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    #    learner = C.adadelta(z.parameters, lr)
    learner = C.fsadagrad(
        z.parameters,
        #apply the learning rate as if it is a minibatch of size 1
        lr,
        momentum=C.momentum_schedule(
            0.9366416204111472,
            minibatch_size=training_config['minibatch_size']),
        gradient_clipping_threshold_per_sample=2.3,
        gradient_clipping_with_truncation=True)
    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    trainer = C.Trainer(z, loss, learner, progress_writers)

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(z.outputs + loss.outputs)  #this is for validation only

    epoch_stat = {'best_val_err': 1000, 'best_since': 0, 'val_since': 0}
    print(restore, os.path.isfile(model_file))
    #    if restore and os.path.isfile(model_file):
    if restore and os.path.isfile(model_file):
        z.restore(model_file)
        #after restore always re-evaluate
        #TODO replace with rougel with external script(possibly)
        #epoch_stat['best_val_err'] = validate_model(i2w, os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1

        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            #TODO replace with rougel with external script(possibly)
            val_err = validate_model(
                i2w, os.path.join(data_path, training_config['val_data']),
                model, polymath)
            #if epoch_stat['best_val_err'] > val_err:
            #    epoch_stat['best_val_err'] = val_err
            #    epoch_stat['best_since'] = 0
            #    trainer.save_checkpoint(model_file)
            #    for p in trainer.model.parameters:
            #        p.value = temp[p.uid]
            #else:
            #    epoch_stat['best_since'] += 1
            #    if epoch_stat['best_since'] > training_config['stop_after']:
            #        return False
            z.save(model_file)
            epoch_stat['best_since'] += 1
            if epoch_stat['best_since'] > training_config['stop_after']:
                return False
        if profiling:
            C.debugging.enable_profiler()

        return True

    init_pointer_importance = polymath.pointer_importance
    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file,
                                                 polymath)

        minibatch_size = training_config['minibatch_size']  # number of samples
        epoch_size = training_config['epoch_size']

        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config[
                        'distributed_after']:
                    data = mb_source.next_minibatch(
                        minibatch_size * C.Communicator.num_workers(),
                        input_map=input_map,
                        num_data_partitions=C.Communicator.num_workers(),
                        partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size,
                                                    input_map=input_map)
                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
            print('Before Pointer_importance:', polymath.pointer_importance)
            if polymath.pointer_importance > 0.1 * init_pointer_importance:
                polymath.pointer_importance = polymath.pointer_importance * 0.9
                print('Pointer_importance:', polymath.pointer_importance)
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config[
            'minibatch_seqs']  # number of sequences

        for epoch in range(max_epochs):  # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath,
                                           minibatch_seqs,
                                           C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count %
                        C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data)  # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()
Code Example #15
def train(data_path, model_path, log_file, config_file, restore=False, profiling=False, gen_heartbeat=False):
    polymath = PolyMath(config_file)
    z, loss = polymath.model()
    training_config = importlib.import_module(config_file).training_config

    max_epochs = training_config['max_epochs']
    log_freq = training_config['log_freq']

    progress_writers = [C.logging.ProgressPrinter(
                            num_epochs = max_epochs,
                            freq = log_freq,
                            tag = 'Training',
                            log_to_file = log_file,
                            rank = C.Communicator.rank(),
                            gen_heartbeat = gen_heartbeat)]

    lr = C.learning_parameter_schedule(training_config['lr'], minibatch_size=None, epoch_size=None)

    ema = {}
    dummies = []
    for p in z.parameters:
        ema_p = C.constant(0, shape=p.shape, dtype=p.dtype, name='ema_%s' % p.uid)
        ema[p.uid] = ema_p
        dummies.append(C.reduce_sum(C.assign(ema_p, 0.999 * ema_p + 0.001 * p)))
    dummy = C.combine(dummies)

    learner = C.adadelta(z.parameters, lr)

    if C.Communicator.num_workers() > 1:
        learner = C.data_parallel_distributed_learner(learner)

    tensorboard_writer = C.logging.TensorBoardProgressWriter(freq=10, log_dir='log', model=z)
    trainer = C.Trainer(z, (loss, None), learner, progress_writers + [tensorboard_writer])

    if profiling:
        C.debugging.start_profiler(sync_gpu=True)

    train_data_file = os.path.join(data_path, training_config['train_data'])
    train_data_ext = os.path.splitext(train_data_file)[-1].lower()

    model_file = os.path.join(model_path, model_name)
    model = C.combine(list(z.outputs) + [loss.output])
    label_ab = argument_by_name(loss, 'ab')

    epoch_stat = {
        'best_val_err' : 100,
        'best_since'   : 0,
        'val_since'    : 0}

    if restore and os.path.isfile(model_file):
        trainer.restore_from_checkpoint(model_file)
        #after restore always re-evaluate
        epoch_stat['best_val_err'] = validate_model(os.path.join(data_path, training_config['val_data']), model, polymath)

    def post_epoch_work(epoch_stat):
        trainer.summarize_training_progress()
        epoch_stat['val_since'] += 1

        if epoch_stat['val_since'] == training_config['val_interval']:
            epoch_stat['val_since'] = 0
            temp = dict((p.uid, p.value) for p in z.parameters)
            for p in trainer.model.parameters:
                p.value = ema[p.uid].value
            val_err = validate_model(os.path.join(data_path, training_config['val_data']), model, polymath)
            if epoch_stat['best_val_err'] > val_err:
                epoch_stat['best_val_err'] = val_err
                epoch_stat['best_since'] = 0
                trainer.save_checkpoint(model_file)
                for p in trainer.model.parameters:
                    p.value = temp[p.uid]
            else:
                epoch_stat['best_since'] += 1
                if epoch_stat['best_since'] > training_config['stop_after']:
                    return False

        if profiling:
            C.debugging.enable_profiler()

        return True

    if train_data_ext == '.ctf':
        mb_source, input_map = create_mb_and_map(loss, train_data_file, polymath)

        minibatch_size = training_config['minibatch_size'] # number of samples
        epoch_size = training_config['epoch_size']

        for epoch in range(max_epochs):
            num_seq = 0
            while True:
                if trainer.total_number_of_samples_seen >= training_config['distributed_after']:
                    data = mb_source.next_minibatch(minibatch_size*C.Communicator.num_workers(), input_map=input_map, num_data_partitions=C.Communicator.num_workers(), partition_index=C.Communicator.rank())
                else:
                    data = mb_source.next_minibatch(minibatch_size, input_map=input_map)

                trainer.train_minibatch(data)
                num_seq += trainer.previous_minibatch_sample_count
                dummy.eval()
                if num_seq >= epoch_size:
                    break
            if not post_epoch_work(epoch_stat):
                break
    else:
        if train_data_ext != '.tsv':
            raise Exception("Unsupported format")

        minibatch_seqs = training_config['minibatch_seqs'] # number of sequences

        for epoch in range(max_epochs):       # loop over epochs
            tsv_reader = create_tsv_reader(loss, train_data_file, polymath, minibatch_seqs, C.Communicator.num_workers())
            minibatch_count = 0
            for data in tsv_reader:
                if (minibatch_count % C.Communicator.num_workers()) == C.Communicator.rank():
                    trainer.train_minibatch(data) # update model with it
                    dummy.eval()
                minibatch_count += 1
            if not post_epoch_work(epoch_stat):
                break

    if profiling:
        C.debugging.stop_profiler()