Example #1
    def test_horovod_allreduce_cpu_gpu_error(self):
        """Test that the allreduce raises an error if different ranks try to
        perform reduction on CPU and GPU."""
        # Only do this test if there are GPUs available.
        if not torch.cuda.is_available():
            return

        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # This test does not apply if there is only one worker.
        if size == 1:
            return

        # Same dims on every rank, but even ranks place the tensor on GPU and odd ranks on CPU
        dims = [17] * 3
        if rank % 2 == 0:
            tensor = torch.cuda.FloatTensor(*dims)
        else:
            tensor = torch.FloatTensor(*dims)

        try:
            hvd.allreduce(tensor)
            assert False, 'hvd.allreduce did not throw error'
        except torch.FatalError:
            pass
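
For contrast, a minimal sketch of a placement that does work: every rank puts the tensor on the same kind of device, e.g. its own GPU, before calling allreduce (illustrative only, assuming one GPU per process and a Horovod build with GPU support):

import torch
import horovod.torch as hvd

hvd.init()
if torch.cuda.is_available():
    torch.cuda.set_device(hvd.local_rank())
    tensor = torch.ones(17, 17, 17, device="cuda")
else:
    tensor = torch.ones(17, 17, 17)
# every rank uses the same device kind, so the collective succeeds
summed = hvd.allreduce(tensor, average=False)
assert torch.allclose(summed.cpu(), torch.full((17, 17, 17), float(hvd.size())))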
Example #2
    def test_horovod_allreduce_error(self):
        """Test that the allreduce raises an error if different ranks try to
        send tensors of different rank or dimension."""
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # This test does not apply if there is only one worker.
        if size == 1:
            return

        # Same rank, different dimension
        torch.manual_seed(1234)
        dims = [17 + rank] * 3
        tensor = torch.FloatTensor(*dims).random_(-100, 100)
        try:
            hvd.allreduce(tensor)
            assert False, 'hvd.allreduce did not throw error'
        except torch.FatalError:
            pass

        # Same number of elements, different rank
        torch.manual_seed(1234)
        if rank == 0:
            dims = [17, 23 * 57]
        else:
            dims = [17, 23, 57]
        tensor = torch.FloatTensor(*dims).random_(-100, 100)
        try:
            hvd.allreduce(tensor)
            assert False, 'hvd.allreduce did not throw error'
        except torch.FatalError:
            pass
Example #3
    def test_horovod_allreduce_average(self):
        """Test that the allreduce correctly sums 1D, 2D, 3D tensors."""
        hvd.init()
        size = hvd.size()
        dtypes = [torch.IntTensor, torch.LongTensor,
                  torch.FloatTensor, torch.DoubleTensor]
        if torch.cuda.is_available():
            dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                       torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
        dims = [1, 2, 3]
        for dtype, dim in itertools.product(dtypes, dims):
            torch.manual_seed(1234)
            tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
            tensor = tensor.type(dtype)
            averaged = hvd.allreduce(tensor, average=True)
            max_difference = averaged.data.sub(tensor).max()

            # Threshold for floating point equality depends on number of
            # ranks, since we're comparing against precise multiplication.
            if size <= 3 or dtype in [torch.IntTensor, torch.LongTensor,
                                      torch.cuda.IntTensor, torch.cuda.LongTensor]:
                threshold = 0
            elif size < 10:
                threshold = 1e-4
            elif size < 15:
                threshold = 5e-4
            else:
                break

            assert max_difference <= threshold, 'hvd.allreduce produces incorrect results'
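
The averaging semantics the test relies on can be seen in a tiny standalone sketch (illustrative only, launched e.g. with horovodrun -np 2): with average=True every rank receives the element-wise mean of all ranks' tensors.

import torch
import horovod.torch as hvd

hvd.init()
# every rank contributes its own rank id; the mean over ranks is (0 + 1 + ... + size-1) / size
x = torch.FloatTensor([hvd.rank()])
avg = hvd.allreduce(x, average=True)
expected = sum(range(hvd.size())) / float(hvd.size())
assert abs(avg.item() - expected) < 1e-6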
Example #4
    def test_horovod_allreduce_grad(self):
        """Test the correctness of the allreduce gradient."""
        hvd.init()
        size = hvd.size()
        dtypes = [torch.IntTensor, torch.LongTensor,
                  torch.FloatTensor, torch.DoubleTensor]
        if torch.cuda.is_available():
            dtypes += [torch.cuda.IntTensor, torch.cuda.LongTensor,
                       torch.cuda.FloatTensor, torch.cuda.DoubleTensor]
        dims = [1, 2, 3]
        for dtype, dim in itertools.product(dtypes, dims):
            torch.manual_seed(1234)
            tensor = torch.FloatTensor(*([17] * dim)).random_(-100, 100)
            tensor = tensor.type(dtype)
            tensor = torch.autograd.Variable(tensor, requires_grad=True)
            summed = hvd.allreduce(tensor, average=False)

            summed.backward(torch.ones([17] * dim).type(dtype))
            grad_out = tensor.grad.data.cpu().numpy()

            expected = np.ones([17] * dim) * size
            err = np.linalg.norm(expected - grad_out)
            self.assertLess(err, 0.00000001,
                            "gradient %s differs from expected %s, "
                            "error: %s" % (grad_out, expected, str(err)))
Example #5
    def test_horovod_allreduce_type_error(self):
        """Test that the allreduce raises an error if different ranks try to
        send tensors of different type."""
        hvd.init()
        rank = hvd.rank()
        size = hvd.size()

        # This test does not apply if there is only one worker.
        if size == 1:
            return

        # Same dims on every rank, but different tensor types
        dims = [17] * 3
        if rank % 2 == 0:
            tensor = torch.IntTensor(*dims)
        else:
            tensor = torch.FloatTensor(*dims)

        try:
            hvd.allreduce(tensor)
            assert False, 'hvd.allreduce did not throw error'
        except torch.FatalError:
            pass
Example #6
def train_once(
    sess,
    step,
    ops,
    names=None,
    gen_feed_dict_fn=None,
    deal_results_fn=None,
    interval_steps=100,
    eval_ops=None,
    eval_names=None,
    gen_eval_feed_dict_fn=None,
    deal_eval_results_fn=melt.print_results,
    valid_interval_steps=100,
    print_time=True,
    print_avg_loss=True,
    model_dir=None,
    log_dir=None,
    is_start=False,
    num_steps_per_epoch=None,
    metric_eval_fn=None,
    metric_eval_interval_steps=0,
    summary_excls=None,
    fixed_step=None,  # for epoch only, in case you change batch size
    eval_loops=1,
    learning_rate=None,
    learning_rate_patience=None,
    learning_rate_decay_factor=None,
    num_epochs=None,
    model_path=None,
    use_horovod=False,
):
    use_horovod = 'OMPI_COMM_WORLD_RANK' in os.environ
    if use_horovod:
        if FLAGS.torch:
            import horovod.torch as hvd
        else:
            import horovod.tensorflow as hvd

    #is_start = False # force not to evaluate at first step
    #print('-----------------global_step', sess.run(tf.train.get_or_create_global_step()))
    timer = gezi.Timer()
    if print_time:
        if not hasattr(train_once, 'timer'):
            train_once.timer = Timer()
            train_once.eval_timer = Timer()
            train_once.metric_eval_timer = Timer()

    melt.set_global('step', step)
    epoch = (fixed_step
             or step) / num_steps_per_epoch if num_steps_per_epoch else -1
    if not num_epochs:
        epoch_str = 'epoch:%.3f' % (epoch) if num_steps_per_epoch else ''
    else:
        epoch_str = 'epoch:%.3f/%d' % (
            epoch, num_epochs) if num_steps_per_epoch else ''
    melt.set_global('epoch', '%.2f' % (epoch))

    info = IO()
    stop = False

    if eval_names is None:
        if names:
            eval_names = list(names)

    if names:
        names = ['train/' + x for x in names]

    if eval_names:
        eval_names = ['eval/' + x for x in eval_names]

    is_eval_step = is_start or (valid_interval_steps and step % valid_interval_steps == 0)
    summary_str = []

    eval_str = ''
    if is_eval_step:
        # deal with summary
        if log_dir:
            if not hasattr(train_once, 'summary_op'):
                #melt.print_summary_ops()
                if summary_excls is None:
                    train_once.summary_op = tf.summary.merge_all()
                else:
                    summary_ops = []
                    for op in tf.get_collection(tf.GraphKeys.SUMMARIES):
                        for summary_excl in summary_excls:
                            if not summary_excl in op.name:
                                summary_ops.append(op)
                    print('filtered summary_ops:')
                    for op in summary_ops:
                        print(op)
                    train_once.summary_op = tf.summary.merge(summary_ops)

                #train_once.summary_train_op = tf.summary.merge_all(key=melt.MonitorKeys.TRAIN)
                train_once.summary_writer = tf.summary.FileWriter(
                    log_dir, sess.graph)

                tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
                    train_once.summary_writer, projector_config)

        # if eval ops then should have been rank 0

        if eval_ops:
            #if deal_eval_results_fn is None and eval_names is not None:
            #  deal_eval_results_fn = lambda x: melt.print_results(x, eval_names)
            for i in range(eval_loops):
                eval_feed_dict = {} if gen_eval_feed_dict_fn is None else gen_eval_feed_dict_fn(
                )
                #eval_feed_dict.update(feed_dict)

                # if using horovod, let each rank use the same sess.run!
                if not log_dir or train_once.summary_op is None or gezi.env_has(
                        'EVAL_NO_SUMMARY') or use_horovod:
                    #if not log_dir or train_once.summary_op is None:
                    eval_results = sess.run(eval_ops, feed_dict=eval_feed_dict)
                else:
                    eval_results = sess.run(eval_ops + [train_once.summary_op],
                                            feed_dict=eval_feed_dict)
                    summary_str = eval_results[-1]
                    eval_results = eval_results[:-1]
                eval_loss = gezi.get_singles(eval_results)
                #timer_.print()
                eval_stop = False
                if use_horovod:
                    sess.run(hvd.allreduce(tf.constant(0)))

                #if not use_horovod or  hvd.local_rank() == 0:
                # @TODO user print should also use logging as a must ?
                #print(gezi.now_time(), epoch_str, 'eval_step: %d'%step, 'eval_metrics:', end='')
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                #if not use_horovod or hvd.rank() == 0:
                #  logging.info2('{} eval_step:{} eval_metrics:{}'.format(epoch_str, step, melt.parse_results(eval_loss, eval_names_)))
                eval_str = 'valid:{}'.format(
                    melt.parse_results(eval_loss, eval_names_))

                # if deal_eval_results_fn is not None:
                #   eval_stop = deal_eval_results_fn(eval_results)

                assert len(eval_loss) > 0
                if eval_stop is True:
                    stop = True
                eval_names_ = melt.adjust_names(eval_loss, eval_names)
                if not use_horovod or hvd.rank() == 0:
                    melt.set_global('eval_loss',
                                    melt.parse_results(eval_loss, eval_names_))

        elif interval_steps != valid_interval_steps:
            #print()
            pass

    metric_evaluate = False

    # if metric_eval_fn is not None \
    #   and (is_start \
    #     or (num_steps_per_epoch and step % num_steps_per_epoch == 0) \
    #          or (metric_eval_interval_steps \
    #              and step % metric_eval_interval_steps == 0)):
    #  metric_evaluate = True

    if metric_eval_fn is not None \
      and ((is_start or metric_eval_interval_steps \
           and step % metric_eval_interval_steps == 0) or model_path):
        metric_evaluate = True

    if 'EVFIRST' in os.environ:
        if os.environ['EVFIRST'] == '0':
            if is_start:
                metric_evaluate = False
        else:
            if is_start:
                metric_evaluate = True

    if step == 0 or 'QUICK' in os.environ:
        metric_evaluate = False

    #print('------------1step', step, 'pre metric_evaluate', metric_evaluate, hvd.rank())
    if metric_evaluate:
        # if use_horovod:
        #   print('------------metric evaluate step', step, model_path, hvd.rank())
        if not model_path or 'model_path' not in inspect.getargspec(
                metric_eval_fn).args:
            metric_eval_fn_ = metric_eval_fn
        else:
            metric_eval_fn_ = lambda: metric_eval_fn(model_path=model_path)

        try:
            l = metric_eval_fn_()
            if isinstance(l, tuple):
                num_returns = len(l)
                if num_returns == 2:
                    evaluate_results, evaluate_names = l
                    evaluate_summaries = None
                else:
                    assert num_returns == 3, 'returning 1, 2 or 3 values is ok; 4 or more is not'
                    evaluate_results, evaluate_names, evaluate_summaries = l
            else:  # returned a dict
                evaluate_results, evaluate_names = tuple(zip(*l.items()))
                evaluate_summaries = None
        except Exception:
            logging.info('Do nothing for metric eval fn with exception:\n',
                         traceback.format_exc())

        if not use_horovod or hvd.rank() == 0:
            #logging.info2('{} valid_step:{} {}:{}'.format(epoch_str, step, 'valid_metrics' if model_path is None else 'epoch_valid_metrics', melt.parse_results(evaluate_results, evaluate_names)))
            logging.info2('{} valid_step:{} {}:{}'.format(
                epoch_str, step, 'valid_metrics',
                melt.parse_results(evaluate_results, evaluate_names)))

        if learning_rate is not None and (learning_rate_patience
                                          and learning_rate_patience > 0):
            assert learning_rate_decay_factor > 0 and learning_rate_decay_factor < 1
            valid_loss = evaluate_results[0]
            if not hasattr(train_once, 'min_valid_loss'):
                train_once.min_valid_loss = valid_loss
                train_once.deacy_steps = []
                train_once.patience = 0
            else:
                if valid_loss < train_once.min_valid_loss:
                    train_once.min_valid_loss = valid_loss
                    train_once.patience = 0
                else:
                    train_once.patience += 1
                    logging.info2('{} valid_step:{} patience:{}'.format(
                        epoch_str, step, train_once.patience))

            if learning_rate_patience and train_once.patience >= learning_rate_patience:
                lr_op = ops[1]
                lr = sess.run(lr_op) * learning_rate_decay_factor
                train_once.deacy_steps.append(step)
                logging.info2(
                    '{} valid_step:{} learning_rate_decay by *{}, learning_rate_decay_steps={}'
                    .format(epoch_str, step, learning_rate_decay_factor,
                            ','.join(map(str, train_once.deacy_steps))))
                sess.run(tf.assign(lr_op, tf.constant(lr, dtype=tf.float32)))
                train_once.patience = 0
                train_once.min_valid_loss = valid_loss

    if ops is not None:
        #if deal_results_fn is None and names is not None:
        #  deal_results_fn = lambda x: melt.print_results(x, names)

        feed_dict = {} if gen_feed_dict_fn is None else gen_feed_dict_fn()
        # NOTICE ops[2] should be scalar otherwise wrong!! loss should be scalar
        #print('---------------ops', ops)
        if eval_ops is not None or not log_dir or not hasattr(
                train_once,
                'summary_op') or train_once.summary_op is None or use_horovod:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops, feed_dict=feed_dict)
        else:
            ## TODO why below ?
            #try:
            feed_dict[K.learning_phase()] = 1
            results = sess.run(ops + [train_once.summary_op],
                               feed_dict=feed_dict)
            summary_str = results[-1]
            results = results[:-1]
            # except Exception:
            #   logging.info('sess.run(ops + [train_once.summary_op], feed_dict=feed_dict) fail')
            #   results = sess.run(ops, feed_dict=feed_dict)

        #print('------------results', results)
        # #--------trace debug
        # if step == 210:
        #   run_metadata = tf.RunMetadata()
        #   results = sess.run(
        #         ops,
        #         feed_dict=feed_dict,
        #         options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
        #         run_metadata=run_metadata)
        #   from tensorflow.python.client import timeline
        #   trace = timeline.Timeline(step_stats=run_metadata.step_stats)

        #   trace_file = open('timeline.ctf.json', 'w')
        #   trace_file.write(trace.generate_chrome_trace_format())

        # results[0] is assumed to be train_op, results[1] the learning_rate
        learning_rate = results[1]
        results = results[2:]

        #@TODO should support aver loss and other avg evaluations like test..
        if print_avg_loss:
            if not hasattr(train_once, 'avg_loss'):
                train_once.avg_loss = AvgScore()
            #assume results[0] as train_op return, results[1] as loss
            loss = gezi.get_singles(results)
            train_once.avg_loss.add(loss)

        steps_per_second = None
        instances_per_second = None
        hours_per_epoch = None
        #step += 1
        #if is_start or interval_steps and step % interval_steps == 0:
        interval_ok = not use_horovod or hvd.local_rank() == 0
        if interval_steps and step % interval_steps == 0 and interval_ok:
            train_average_loss = train_once.avg_loss.avg_score()
            if print_time:
                duration = timer.elapsed()
                duration_str = 'duration:{:.2f} '.format(duration)
                melt.set_global('duration', '%.2f' % duration)
                #info.write(duration_str)
                elapsed = train_once.timer.elapsed()
                steps_per_second = interval_steps / elapsed
                batch_size = melt.batch_size()
                num_gpus = melt.num_gpus()
                instances_per_second = interval_steps * batch_size / elapsed
                gpu_info = '' if num_gpus <= 1 else ' gpus:[{}]'.format(
                    num_gpus)
                if num_steps_per_epoch is None:
                    epoch_time_info = ''
                else:
                    hours_per_epoch = num_steps_per_epoch / interval_steps * elapsed / 3600
                    epoch_time_info = '1epoch:[{:.2f}h]'.format(
                        hours_per_epoch)
                    # info.write('elapsed:[{:.2f}] batch_size:[{}]{} batches/s:[{:.2f}] insts/s:[{:.2f}] {} lr:[{:.6f}]'.format(
                    #               elapsed, batch_size, gpu_info, steps_per_second, instances_per_second, epoch_time_info, learning_rate))
                    info.write(
                        'elap:[{:.2f}] batch:[{}] {} lr:[{:.6f}]'.format(
                            elapsed, batch_size, epoch_time_info,
                            learning_rate))

            if print_avg_loss:
                #info.write('train_avg_metrics:{} '.format(melt.value_name_list_str(train_average_loss, names)))
                names_ = melt.adjust_names(train_average_loss, names)
                #info.write('train_avg_metric:{} '.format(melt.parse_results(train_average_loss, names_)))
                info.write(' train:{} '.format(
                    melt.parse_results(train_average_loss, names_)))
                #info.write('train_avg_loss: {} '.format(train_average_loss))
            info.write(eval_str)
            #print(gezi.now_time(), epoch_str, 'train_step:%d'%step, info.getvalue(), end=' ')
            logging.info2('{} {} {}'.format(epoch_str, 'step:%d' % step,
                                            info.getvalue()))

            if deal_results_fn is not None:
                stop = deal_results_fn(results)

    summary_strs = gezi.to_list(summary_str)
    if metric_evaluate:
        if evaluate_summaries is not None:
            summary_strs += evaluate_summaries

    if step > 1:
        if is_eval_step:
            # deal with summary
            if log_dir:
                summary = tf.Summary()
                if eval_ops is None:
                    if train_once.summary_op is not None:
                        for summary_str in summary_strs:
                            train_once.summary_writer.add_summary(
                                summary_str, step)
                else:
                    for summary_str in summary_strs:
                        train_once.summary_writer.add_summary(
                            summary_str, step)
                    suffix = 'valid' if not eval_names else ''
                    # loss/valid
                    melt.add_summarys(summary,
                                      eval_results,
                                      eval_names_,
                                      suffix=suffix)

                if ops is not None:
                    try:
                        # loss/train_avg
                        melt.add_summarys(summary,
                                          train_average_loss,
                                          names_,
                                          suffix='train_avg')
                    except Exception:
                        pass
                    ##optimizer has done this also
                    melt.add_summary(summary, learning_rate, 'learning_rate')
                    melt.add_summary(summary,
                                     melt.batch_size(),
                                     'batch_size',
                                     prefix='other')
                    melt.add_summary(summary,
                                     melt.epoch(),
                                     'epoch',
                                     prefix='other')
                    if steps_per_second:
                        melt.add_summary(summary,
                                         steps_per_second,
                                         'steps_per_second',
                                         prefix='perf')
                    if instances_per_second:
                        melt.add_summary(summary,
                                         instances_per_second,
                                         'instances_per_second',
                                         prefix='perf')
                    if hours_per_epoch:
                        melt.add_summary(summary,
                                         hours_per_epoch,
                                         'hours_per_epoch',
                                         prefix='perf')

                if metric_evaluate:
                    #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
                    prefix = 'step_eval'
                    if model_path:
                        prefix = 'eval'
                        valid_interval_epochs = 1.
                        try:
                            valid_interval_epochs = FLAGS.valid_interval_epochs
                        except Exception:
                            pass
                        if not hasattr(train_once, 'epoch_step'):
                            train_once.epoch_step = 1 if melt.epoch(
                            ) <= 1 else int(
                                int(melt.epoch() * 10) /
                                int(valid_interval_epochs * 10))
                        else:
                            train_once.epoch_step += 1
                        step = train_once.epoch_step
                    # eval/loss eval/auc ..
                    melt.add_summarys(summary,
                                      evaluate_results,
                                      evaluate_names,
                                      prefix=prefix)

                train_once.summary_writer.add_summary(summary, step)
                train_once.summary_writer.flush()
            return stop
        elif metric_evaluate and log_dir:
            summary = tf.Summary()
            for summary_str in summary_strs:
                train_once.summary_writer.add_summary(summary_str, step)
            #summary.ParseFromString(evaluate_summaries)
            summary_writer = train_once.summary_writer
            prefix = 'step_eval'
            if model_path:
                prefix = 'eval'
                if not hasattr(train_once, 'epoch_step'):
                    ## TODO.. restart will get 1 again..
                    #epoch_step = tf.Variable(0, trainable=False, name='epoch_step')
                    #epoch_step += 1
                    #train_once.epoch_step = sess.run(epoch_step)
                    valid_interval_epochs = 1.
                    try:
                        valid_interval_epochs = FLAGS.valid_interval_epochs
                    except Exception:
                        pass
                    train_once.epoch_step = 1 if melt.epoch() <= 1 else int(
                        int(melt.epoch() * 10) /
                        int(valid_interval_epochs * 10))
                    logging.info('train_once epoch start step is',
                                 train_once.epoch_step)
                else:
                    #epoch_step += 1
                    train_once.epoch_step += 1
                step = train_once.epoch_step
            #melt.add_summarys(summary, evaluate_results, evaluate_names, prefix='eval')
            melt.add_summarys(summary,
                              evaluate_results,
                              evaluate_names,
                              prefix=prefix)
            summary_writer.add_summary(summary, step)
            summary_writer.flush()
Example #7
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()
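
A typical call site (the values and the metric name are illustrative) averages a per-worker Python scalar such as validation accuracy so that every rank ends up with the same global number:

import horovod.torch as hvd

hvd.init()
# pretend each worker measured a slightly different local accuracy
local_accuracy = 0.90 + 0.01 * hvd.rank()
global_accuracy = metric_average(local_accuracy, "avg_accuracy")  # helper from the example above
if hvd.rank() == 0:
    print("average accuracy over %d workers: %.4f" % (hvd.size(), global_accuracy))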
Example #8
 def barrier(self) -> None:
     # https://github.com/horovod/horovod/issues/159#issuecomment-424834603
     # hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
     hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
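
The trick works because allreduce is a synchronous collective: it cannot complete on any rank until every rank has entered it. A free-function variant for scripts that do not wrap Horovod in a class (a sketch; the name hvd_barrier is ours):

import torch
import horovod.torch as hvd

def hvd_barrier(name="barrier"):
    # reducing a dummy CPU scalar blocks every rank until all ranks reach this call
    hvd.allreduce(torch.tensor(0, device="cpu"), name=name)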
Example #9
def train_loop(
    run_id,
    dataset_dir,
    ckpt_run_dir,
    output_dir,
    validation_only=False,
    use_cuda=False,
    light_target=False,
    seed=42,
):
    """Train loop"""
    train_epochs = 10

    math_mode = "fp16"
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Dataset arguments
    train_global_batch_size = 2**17  # Global batch size
    max_bs = 2**13  # Max batch size for used hardware
    update_freq = int(max(1, train_global_batch_size // (max_bs * world_size)))
    max_tokens = int(train_global_batch_size // (world_size * update_freq))

    max_source_positions, max_target_positions = 80, 80
    seq_len_multiple = 2
    left_pad = (True, False)
    lang = ("en", "de")

    # specific arch
    model_args = deepcopy(DEFAULT_TRANSFORMER_ARCH)
    model_args["max_source_positions"] = max_source_positions
    model_args["max_target_positions"] = max_target_positions
    model_args["share_all_embeddings"] = True
    model_args["dropout"] = 0.1
    model_args["softmax_type"] = "fast_fill"

    lr = 1.976e-3
    optimizer_args = {
        "lr": lr,
        "eps": 1e-9,
        "betas": (0.9, 0.98),
    }
    scheduler_args = {
        "base_lr": lr,
        "warmup_init_lr": 0.0,
        "warmup_steps": 1000
    }

    loss_scaling_fp16 = {
        "init_scale": 2.0**7,
        "scale_factor": 2,
        "scale_window": 2000,
    }

    criterion_args = {"smoothing": 0.1, "fast_xentropy": True}

    # Horovod stuff
    use_horovod = (math_mode
                   == "fp16") and dist.get_backend() == dist.Backend.MPI
    if use_horovod:
        hvd.init()
        logger.info("Using horovod rank={}".format(hvd.rank()))
        tensor = torch.tensor([1])
        res = hvd.allreduce(tensor, op=hvd.Sum)
        assert res[0] == world_size

    # Load train and validation datasets
    train_set = WMT17Dataset(
        dataset_dir,
        download=True,
        train=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )

    validation_set = WMT17Dataset(
        dataset_dir,
        download=False,
        test=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    src_dict, trg_dict = train_set.src_dict, train_set.trg_dict

    train_batches = get_batches(train_set,
                                max_tokens=max_tokens,
                                bsz_mult=8,
                                shuffle=True,
                                seed=seed)
    val_batches = get_batches(validation_set,
                              max_tokens=max_tokens,
                              bsz_mult=8,
                              shuffle=False)

    train_batches = equalize_batches(train_batches, world_size, seed=seed)

    # Partition by rank
    train_batches = partition_dataset_by_rank(train_batches, rank, world_size)
    val_batches = partition_dataset_by_rank(val_batches, rank, world_size)

    total_train_points = sum(len(b) for b in train_batches)

    validate_every = update_freq * round(
        len(train_batches) * 0.30 / update_freq)  # Validate every 30%

    assert (validate_every % update_freq) == 0
    logger.info("Using {} total train points, {} batches".format(
        total_train_points, len(train_batches)))

    train_loader = DataLoader(
        train_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=train_set.collater,
        batch_sampler=train_batches,
    )

    val_loader = DataLoader(
        validation_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=validation_set.collater,
        batch_sampler=val_batches,
    )

    model = TransformerModel(Arguments(model_args), src_dict, trg_dict)
    criterion = LabelSmoothing(padding_idx=src_dict.pad(), **criterion_args)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    fp_optimizer, optimizer, model = build_optimizer(
        model,
        optimizer_args,
        math_mode=math_mode,
        scaling_args=loss_scaling_fp16,
        use_horovod=use_horovod,
        use_cuda=use_cuda,
    )

    scheduler = SQRTTimeDecayLRWithWarmup(optimizer, **scheduler_args)

    metrics = [BLEUScore(use_raw=True)]
    checkpointer = Checkpointer(ckpt_run_dir=ckpt_run_dir,
                                rank=rank,
                                freq=CheckpointFreq.BEST)

    translator = SequenceGenerator(
        model,
        src_dict=deepcopy(src_dict),
        trg_dict=deepcopy(trg_dict),
        beam_size=4,
        stop_early=True,
        normalize_scores=True,
        len_penalty=0.6,
        sampling=False,
        sampling_topk=-1,
        minlen=1,
    )
    if not validation_only:

        if light_target:
            goal = task4_time_to_bleu_goal(20)
        else:
            goal = task4_time_to_bleu_goal(25)

        num_batches_per_device_train = len(train_loader)
        tracker = Tracker(metrics, run_id, rank, goal=goal)

        dist.barrier()
        tracker.start()

        for epoch in range(0, train_epochs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            model.train()
            tracker.train()

            iter_sample_size = 0
            for batch_idx, sample in enumerate(train_loader):
                tracker.batch_start()

                sample = prepare_batch(sample, use_cuda=use_cuda)
                tracker.record_batch_load()

                is_last = batch_idx == len(train_loader) - 1
                update = (batch_idx % update_freq) == update_freq - 1
                init = (batch_idx % update_freq) == 0

                # Clear gradients in the optimizer.
                if init:
                    fp_optimizer.zero_grad()
                    iter_sample_size = 0
                    tracker.record_batch_init()

                # Compute the output
                output = model(**sample["net_input"])
                tracker.record_batch_fwd_pass()

                loss, sample_size = compute_loss(sample, output, criterion)
                loss_per_sample = loss.item() / sample_size
                iter_sample_size += sample_size
                tracker.record_batch_comp_loss()

                # Backprop
                fp_optimizer.backward_loss(loss)
                tracker.record_batch_backprop()

                if update or is_last:
                    # Get batch size over all workers
                    full_bs = get_full_batch_size(iter_sample_size,
                                                  world_size=world_size,
                                                  use_cuda=use_cuda)

                    updated = opt_step(
                        fp_optimizer,
                        tracker,
                        full_bs,
                        update_freq,
                        math_mode,
                        world_size,
                    )

                    if updated:
                        scheduler.step()

                tracker.batch_end()

                record_train_batch_stats(
                    batch_idx=batch_idx,
                    loss=loss_per_sample,
                    output=torch.Tensor([0]),
                    metric_results={},
                    tracker=tracker,
                    num_batches_per_device_train=num_batches_per_device_train,
                )

                if (batch_idx + 1) % validate_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    metric_values, loss = validation_round(
                        val_loader,
                        metrics,
                        criterion,
                        translator,
                        tracker=tracker,
                        use_cuda=use_cuda,
                    )
                    record_validation_stats(metric_values, loss, tracker, rank)
                    if tracker.goal_reached:
                        break

                    model.train()
                    tracker.train()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            metric_values, loss = validation_round(
                val_loader,
                metrics,
                criterion,
                translator,
                tracker=tracker,
                use_cuda=use_cuda,
            )
            is_best = record_validation_stats(metric_values, loss, tracker,
                                              rank)
            checkpointer.save(
                tracker,
                model,
                optimizer,
                scheduler,
                tracker.current_epoch,
                is_best,
            )
            tracker.epoch_end()

            if tracker.goal_reached:
                print("Goal Reached!")
                time.sleep(10)
                return
    else:
        cecf = CheckpointsEvaluationControlFlow(
            ckpt_dir=ckpt_run_dir,
            rank=rank,
            world_size=world_size,
            checkpointer=checkpointer,
            model=model,
            epochs=train_epochs,
            loss_function=criterion,
            metrics=metrics,
            use_cuda=use_cuda,
            dtype="fp32",
            max_batch_per_epoch=None,
        )

        train_stats = cecf.evaluate_by_epochs(train_loader)
        with open(os.path.join(output_dir, "train_stats.json"), "w") as f:
            json.dump(train_stats, f)
Example #10
def metric_sum_hvd(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name, average=False)
    return avg_tensor.item()
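
With average=False the helper returns the global sum rather than the mean, which is what you want when combining counts. A sketch (the per-worker counts are illustrative) of computing a dataset-wide accuracy from local correct/total counts:

import horovod.torch as hvd

hvd.init()
local_correct, local_total = 87.0, 100.0  # illustrative per-worker counts
global_correct = metric_sum_hvd(local_correct, "correct")  # helper from the example above
global_total = metric_sum_hvd(local_total, "total")
if hvd.rank() == 0:
    print("global accuracy: %.4f" % (global_correct / global_total))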
Example #11
def evaluate(args):
    # initialize Horovod library
    hvd.init()
    # Horovod limits CPU threads to be used per worker
    torch.set_num_threads(1)

    if hvd.local_rank() == 0 and not os.path.exists(args.dir):
        # create 16 random image/mask pairs for evaluation
        print(
            f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(16):
            im, seg = create_test_image_3d(128,
                                           128,
                                           128,
                                           num_seg_classes=1,
                                           channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    val_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    val_transforms = Compose([
        LoadNiftid(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        ToTensord(keys=["img", "seg"]),
    ])

    # create an evaluation dataset
    val_ds = Dataset(data=val_files, transform=val_transforms)
    # create an evaluation data sampler
    val_sampler = DistributedSampler(val_ds,
                                     shuffle=False,
                                     num_replicas=hvd.size(),
                                     rank=hvd.rank())
    # when supported, use "forkserver" to spawn dataloader workers instead of "fork" to prevent
    # issues with Infiniband implementations that are not fork-safe
    multiprocessing_context = None
    if hasattr(
            mp, "_supports_context"
    ) and mp._supports_context and "forkserver" in mp.get_all_start_methods():
        multiprocessing_context = "forkserver"
    # sliding window inference needs to input 1 image in every iteration
    val_loader = DataLoader(
        val_ds,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
        sampler=val_sampler,
        multiprocessing_context=multiprocessing_context,
    )
    dice_metric = DiceMetric(include_background=True,
                             to_onehot_y=False,
                             sigmoid=True,
                             reduction="mean")

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{hvd.local_rank()}")
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    if hvd.rank() == 0:
        # load model parameters for evaluation
        model.load_state_dict(torch.load("final_model.pth"))
    # Horovod broadcasts parameters
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    model.eval()
    with torch.no_grad():
        # define PyTorch Tensor to record metrics result at each GPU
        # the first value is `sum` of all dice metric, the second value is `count` of not_nan items
        metric = torch.zeros(2, dtype=torch.float, device=device)
        for val_data in val_loader:
            val_images, val_labels = val_data["img"].to(
                device), val_data["seg"].to(device)
            # define sliding window size and batch size for windows inference
            roi_size = (96, 96, 96)
            sw_batch_size = 4
            val_outputs = sliding_window_inference(val_images, roi_size,
                                                   sw_batch_size, model)
            value = dice_metric(y_pred=val_outputs, y=val_labels).squeeze()
            metric[0] += value * dice_metric.not_nans
            metric[1] += dice_metric.not_nans
        # synchronizes all processes and reduce results
        print(
            f"metric in rank {hvd.rank()}: sum={metric[0].item()}, count={metric[1].item()}"
        )
        avg_metric = hvd.allreduce(metric, name="mean_dice")
        if hvd.rank() == 0:
            print(
                f"average metric: sum={avg_metric[0].item()}, count={avg_metric[1].item()}"
            )
            print("evaluation metric:", (avg_metric[0] / avg_metric[1]).item())
Example #12
 def average_value(self, val, name):
     avg_tensor = hvd.allreduce(val, name=name)
     return avg_tensor
 def update(self, val):
     import horovod.torch as hvd
     self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
     self.n += 1
Example #14
 def update(self, val):
     self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
     self.n += 1
def metric_average(val, name):
    tensor = torch.tensor(val)
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.item()
def metric_sum(value):
    return hvd.allreduce(torch.tensor(value), op=hvd.Sum).item()
def metric_ave(value):
    return hvd.allreduce(torch.tensor(value)).item()
Example #18
 def update(self, val, delta_n=1):
     import horovod.torch as hvd
     val *= delta_n
     self.sum += hvd.allreduce(val.detach().cpu(), name=self.name)
     self.count += delta_n
Example #19
 def avg(self):
     import horovod.torch as hvd
     if not self.synced:
         self.sum = hvd.allreduce(self.sum, name=self.name)
         self.synced = True
     return self.sum / self.count
Example #20
def metric_average(val, name):
    tensor = torch.FloatTensor([val])
    avg_tensor = hvd.allreduce(tensor, name=name)
    return avg_tensor.data[0]
Example #21
def run(i_run, options, train_data, valid_data, test_data, model, optimizer,
        handles, outfile):
    train_dataloader, train_idx = create_train_dataset(
        options, data_tensor_dict=train_data)
    valid_dataloader, valid_idx = create_valid_test_dataset(
        options, data_tensor_dict=valid_data)
    test_dataloader, test_idx = create_valid_test_dataset(
        options, data_tensor_dict=test_data)
    total_steps = get_train_step(len(train_data['idx']), 1, options.batchsize,
                                 hvd.size())
    train_step = 0
    best = {'val_acc': 0.0, 'epoch': 0}

    for g in optimizer.param_groups:
        g['lr'] = options.lr
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.1,
        patience=options.lr_patience,
        threshold=0.0001,
        threshold_mode='rel',
        cooldown=0,
        min_lr=1e-2,
        eps=1e-08,
        verbose=True)

    model.reset_parameters()
    hvd.broadcast_optimizer_state(optimizer, root_rank=0)
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)

    train_epoch = 0
    for epoch in range(options.epochs):
        # train
        t0 = time.time()
        total_cla_loss, train_acc = train(train_dataloader, model, optimizer,
                                          total_steps, train_idx,
                                          handles.train_label_handle)
        t1 = time.time()

        #valid
        valid_acc = valid(valid_dataloader, model, len(valid_idx), valid_idx,
                          handles)
        t2 = time.time()
        scheduler.step(valid_acc)

        if valid_acc < best['val_acc']:
            if hvd.rank() == 0:
                print(
                    'run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%,                  train_time=%.2fs, valid_time=%.2fs'
                    % (i_run, epoch, total_cla_loss, train_acc, valid_acc,
                       t1 - t0, t2 - t1))
            if epoch > best['epoch'] + options.stop_patience:
                break
        else:
            #test
            test_acc = test(test_dataloader, model, len(test_idx), test_idx,
                            handles)
            t3 = time.time()
            if hvd.rank() == 0:
                print(
                    'run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%, train_time=%.2fs, valid_time=%.2fs, test_time=%.2fs'
                    % (i_run, epoch, total_cla_loss, train_acc, valid_acc,
                       test_acc, t1 - t0, t2 - t1, t3 - t2))
            best['val_acc'] = valid_acc
            best['loss'] = total_cla_loss
            best['test_acc'] = test_acc
            best['epoch'] = epoch
            best['train_acc'] = train_acc
        hvd.allreduce(torch.tensor(0))

    if hvd.rank() == 0:
        print(
            '[BEST] run=%02d, epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%'
            % (i_run, best['epoch'], best['loss'], best['train_acc'],
               best['val_acc'], best['test_acc']))
        print(
            '[BEST] epoch=%03d, loss=%.4f, train_acc=%.2f%%, valid_acc=%.2f%%, test_acc=%.2f%%'
            % (best['epoch'], best['loss'], best['train_acc'], best['val_acc'],
               best['test_acc']),
            file=outfile)
        outfile.flush()
    return best