Example #1
File: model.py Project: richardbaihe/atec
    def __init__(self, args):
        globals().update(args.__dict__)
        random.seed(seed)
        np.random.seed(seed)
        tf.set_random_seed(seed)
        # self.ps_hosts = ps_hosts.split(',')
        # self.worker_hosts = worker_hosts.split(',')
        self.logger = ResultLogger(path=os.path.join(log_dir,
                                                     '{}.jsonl'.format(desc)),
                                   **args.__dict__)
        self.text_encoder = TextEncoder(encoder_path)
        self.encoder = self.text_encoder.encoder
        self.n_vocab = len(self.text_encoder.encoder)
        self.encoder['_start_'] = len(self.encoder)
        self.encoder['_delimiter_'] = len(self.encoder)
        self.encoder['_end_'] = len(self.encoder)
        self.clf_token = self.encoder['_end_']
        self.n_special = 3
        self.n_batch_train = n_batch * n_gpu
        self.n_updates_total = n_step * 10000
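
Every example on this page builds a ResultLogger from a log path plus the run's hyperparameters, then calls .log(...) once per epoch. The class itself is not shown here; below is a minimal sketch of a JSON-lines logger that matches these call sites. The field names and the default=str fallback are assumptions for illustration, not any project's exact code.

import json
import os
import time


class ResultLogger(object):
    """Append-only JSON-lines logger: the first line records the run
    configuration, each later line records one call to log()."""

    def __init__(self, path, **kwargs):
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        kwargs.setdefault('time', time.time())
        # default=str keeps non-serializable hyperparameter values
        # (e.g. functions in hps.__dict__) from crashing the dump.
        self.f_log = open(path, 'w')
        self.f_log.write(json.dumps(kwargs, default=str) + '\n')

    def log(self, **kwargs):
        kwargs.setdefault('time', time.time())
        self.f_log.write(json.dumps(kwargs, default=str) + '\n')
        self.f_log.flush()

    def close(self):
        self.f_log.close()

Under this sketch, ResultLogger(path=..., **args.__dict__) writes the configuration as the first record, and logger.log(epoch=..., ...) appends one record per epoch, which is exactly how the snippets above and below use it.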
Example #2
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)

    args = parser.parse_args()
    print(args)
    globals().update(args.__dict__)
    random.seed(seed)
    np.random.seed(seed)
    tf.set_random_seed(seed)

    logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
    text_encoder = TextEncoder(encoder_path, bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    #(trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(rocstories(data_dir), encoder=text_encoder)
    #enco_ry = ruoyao(data_dir)
    #(trX1,trX2,tyY), (vaX1, vaX2, vaY), (teX1, teX2) = ruoyao(data_dir)
    #print(trX1[0])
    (trX1, trX2, trY), (vaX1, vaX2, vaY), (teX1, teX2, teY) = encode_dataset(ruoyao(data_dir), encoder=text_encoder)
    n_y = 2
    encoder['_start_'] = len(encoder)
    encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 3
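
Examples #1 and #2 both extend the BPE vocabulary with special tokens using the same idiom: assigning len(encoder) appends each token at the next free id. A toy illustration of the arithmetic (the two-word vocabulary is made up):

encoder = {'hello': 0, 'world': 1}     # stand-in for the 40000-entry BPE vocab
n_vocab = len(encoder)                 # 2

encoder['_start_'] = len(encoder)      # id 2
encoder['_delimiter_'] = len(encoder)  # id 3
encoder['_classify_'] = len(encoder)   # id 4

clf_token = encoder['_classify_']      # 4: marks where the classifier head reads
n_special = 3                          # embedding matrix needs n_vocab + n_special rows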
Example #3
def train(sess, model, hps, logdir):
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print(
        'epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg'
    )

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    tcurr = time.time()
    for epoch in range(1, hps.epochs):

        t = time.time()

        train_results = []
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time() - _t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch,
                             n_processed=n_processed,
                             n_images=n_images,
                             train_time=int(train_time),
                             **process_results(train_results))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)
                print(hps.full_test_its)   # leftover debug output
                print(test_results.shape)  # leftover debug output

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch,
                                    n_processed=n_processed,
                                    n_images=n_images,
                                    **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir + "model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            # if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
            #     visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(
                    epoch, n_processed, n_images,
                    "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                        ips, dtrain, dtest, dsample,
                        dcurr), train_results, test_results, msg)

            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")
Example #4
def main(hps):

    # Initialize Horovod.
    hvd.init()

    # Create tensorflow session
    sess = tensorflow_session()

    # Download and load dataset.
    tf.set_random_seed(hvd.rank() + hvd.size() * hps.seed)
    np.random.seed(hvd.rank() + hvd.size() * hps.seed)

    # Get data and set train_its and valid_its
    train_iterator, test_iterator, data_init = get_data(hps, sess)
    hps.train_its, hps.test_its, hps.full_test_its = get_its(hps)

    # Create log dir
    logdir = os.path.abspath(hps.logdir) + "/"
    if not os.path.exists(logdir):
        os.mkdir(logdir)

    # Create model
    import model
    model = model.model(sess, hps, train_iterator, test_iterator, data_init)

    # Initialize visualization functions
    draw_samples = init_visualizations(hps, model, logdir)

    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    tcurr = time.time()
    for epoch in range(1, hps.epochs):

        t = time.time()

        train_results = []
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                draw_samples(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)

            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")
Example #5
File: train.py Project: chinatian/glow
def train(sess, model, hps, logdir, visualise):
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = 999999

    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    tcurr = time.time()
    for epoch in range(1, hps.epochs):

        t = time.time()

        train_results = []
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results = np.mean(np.asarray(train_results), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        test_loss_best = test_results[0]
                        model.save(logdir+"model_best_loss.ckpt")
                        msg += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results, test_results, msg)

            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")
Example #6
# submit = args.submit
# dataset = args.dataset
# n_ctx = args.n_ctx
# save_dir = args.save_dir
# desc = args.desc # data_dir = args.data_dir
# log_dir = args.log_dir
# submission_dir = args.submission_dir
# iter = args.n_iter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device", device, "n_gpu", n_gpu)

# logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)))

text_encoder = TextEncoder(encoder_path, bpe_path)
encoder = text_encoder.encoder
n_vocab = len(text_encoder.encoder)


print("Encoding dataset...")

((trX, trY), (vaX, vaY), _) = encode_dataset(*imdb(data_dir, n_train=100, n_valid=1000),
                                             encoder=text_encoder)


encoder['_start_'] = len(encoder)
encoder['_delimiter_'] = len(encoder)
encoder['_classify_'] = len(encoder)
Example #7
    save_dir = os.path.join(args.save_dir, desc)
    data_dir = os.path.join(args.data_dir, desc)
    log_dir = os.path.join(args.log_dir, desc)
    submission_dir = args.submission_dir

    for d in (save_dir, log_dir):
        os.makedirs(d, exist_ok=True)

    dataset = args.dataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    log_file = os.path.join(log_dir, '{}.jsonl'.format(dataset))
    logger = ResultLogger(path=log_file, **args.__dict__)
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    print("Encoding dataset...")
    ((trX, trY), (vaX, vaY), (teX, teY)) = encode_dataset(
        *preprocess_fns[dataset](data_dir, sentence_pair=args.sentence_pair),
        encoder=text_encoder,
        skip_preprocess=args.skip_preprocess)
    encoder['_start_'] = len(encoder)
    if args.sentence_pair or args.force_delimiter:
        encoder['_delimiter_'] = len(encoder)
    encoder['_classify_'] = len(encoder)
    clf_token = encoder['_classify_']
    n_special = 2 + int('_delimiter_' in encoder)
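
Unlike the earlier examples, Example #7 only registers '_delimiter_' for sentence-pair tasks, so n_special is derived from the encoder rather than hard-coded. A two-line check of the arithmetic (the toy encoder is made up):

encoder = {'_start_': 0, '_classify_': 1}
assert 2 + int('_delimiter_' in encoder) == 2   # single-sentence task

encoder['_delimiter_'] = 2
assert 2 + int('_delimiter_' in encoder) == 3   # sentence-pair task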
Example #8
def train(sess, model, hps, logdir, visualise):
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('epoch n_processed n_images ips dtrain dtest dsample dtot train_results test_results msg')

    # Train
    sess.graph.finalize()
    n_processed = 0
    n_images = 0
    train_time = 0.0
    test_loss_best = {'A': 999999, 'B': 999999}

    if hvd.rank() == 0:
        train_logger = {'A': ResultLogger(logdir + "train_A.txt", **hps.__dict__),
                        'B': ResultLogger(logdir + "train_B.txt", **hps.__dict__)}
        test_logger = {'A': ResultLogger(logdir + "test_A.txt", **hps.__dict__),
                       'B': ResultLogger(logdir + "test_B.txt", **hps.__dict__)}

    tcurr = time.time()
    for epoch in range(1, hps.epochs):
        t = time.time()
        train_results = {'A': [], 'B': []}
        for it in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            x_A, y_A, x_B, y_B = model.get_train_data()
            train_results_A, train_results_B = model.train(
                lr, x_A, y_A, x_B, y_B)
            train_results['A'] += [train_results_A]
            train_results['B'] += [train_results_B]
            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results['A'][-1])
                _print(n_processed, time.time()-_t, train_results['B'][-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results['A'] = np.mean(np.asarray(train_results['A']), axis=0)
        train_results['B'] = np.mean(np.asarray(train_results['B']), axis=0)

        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger['A'].log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results['A']))
            train_logger['B'].log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results['B']))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = {'A': [], 'B': []}
            msg = {'A': 'A', 'B': 'B'}

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for it in range(hps.full_test_its):
                    x_A, y_A, x_B, y_B = model.get_test_data()
                    test_results['A'] += [model.test_A(x_A, y_A, x_B, y_B)]
                    test_results['B'] += [model.test_B(x_A, y_A, x_B, y_B)]
                test_results['A'] = np.mean(
                    np.asarray(test_results['A']), axis=0)
                test_results['B'] = np.mean(
                    np.asarray(test_results['B']), axis=0)

                if hvd.rank() == 0:
                    test_logger['A'].log(epoch=epoch, n_processed=n_processed,
                                         n_images=n_images, **process_results(test_results['A']))
                    test_logger['B'].log(epoch=epoch, n_processed=n_processed,
                                         n_images=n_images, **process_results(test_results['B']))
                    # Save checkpoint
                    if test_results['A'][0] < test_loss_best['A']:
                        test_loss_best['A'] = test_results['A'][0]
                        model.save_A(logdir+"model_A_best_loss.ckpt")
                        msg['A'] += ' *'
                    if test_results['B'][0] < test_loss_best['B']:
                        test_loss_best['B'] = test_results['B'][0]
                        model.save_B(logdir+"model_B_best_loss.ckpt")
                        msg['B'] += ' *'

            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                msg['A'] += ', train_time: {}'.format(int(train_time))
                msg['B'] += ', train_time: {}'.format(int(train_time))
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results['A'], test_results['A'], msg['A'])
                _print(epoch, n_processed, n_images, "{:.1f} {:.1f} {:.1f} {:.1f} {:.1f}".format(
                    ips, dtrain, dtest, dsample, dcurr), train_results['B'], test_results['B'], msg['B'])
            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")
Example #9
def train(sess, model, hps, logdir, visualise):
    _print(hps)
    _print('Starting training. Logging to', logdir)
    _print('\nepoch     train   [t_loss, bits_x_u, bits_x_o, bits_y, reg]   '
           'test    n_processed    n_images    (ips, dtrain, dtest, dsample, dtot), msg \n')

    # Train
    sess.graph.finalize()

    # Restore the metadata of the last checkpoint, if a restore path was given.
    if hps.restore_path != '':
        checkpoint_state_json_path = os.path.join(hps.restore_path, "last_checkpoint_state.json")
        state_dict = load_runtime_state(checkpoint_state_json_path)
        current_epoch = state_dict['current_epoch'] + 1
        n_processed = state_dict['n_processed']
        prev_train_loss = state_dict['train_loss_best']
        prev_test_loss = state_dict['test_loss_best']
        _print('Loaded the latest checkpoint (epoch %d, n_p %d) from %s'
               % (current_epoch, n_processed, checkpoint_state_json_path))

    else:
        n_processed = 0
        current_epoch = 0
        prev_test_loss = None
        prev_train_loss = None

    n_images = 0
    train_time = 0.0
    test_loss_best = prev_test_loss if prev_test_loss is not None else 10.0
    train_loss_best = prev_train_loss if prev_train_loss is not None else 10.0

    if hvd.rank() == 0:
        train_logger = ResultLogger(logdir + "train.txt", **hps.__dict__)
        test_logger = ResultLogger(logdir + "test.txt", **hps.__dict__)

    tcurr = time.time()
    for epoch in range(current_epoch, hps.epochs):

        t = time.time()

        train_results = []
        for _ in range(hps.train_its):

            # Set learning rate, linearly annealed from 0 in the first hps.epochs_warmup epochs.
            lr = hps.lr * min(1., n_processed /
                              (hps.n_train * hps.epochs_warmup))

            # Run a training step synchronously.
            _t = time.time()
            train_results += [model.train(lr)]

            if hps.verbose and hvd.rank() == 0:
                _print(n_processed, time.time()-_t, train_results[-1])
                sys.stdout.flush()

            # Images seen wrt anchor resolution
            n_processed += hvd.size() * hps.n_batch_train
            # Actual images seen at current resolution
            n_images += hvd.size() * hps.local_batch_train

        train_results = np.mean(np.asarray(train_results), axis=0)

        if train_results[0] < train_loss_best:
            save_subdir = os.path.join(logdir, 'saved_models', 'best_train_loss')
            os.makedirs(save_subdir, exist_ok=True)
            train_loss_best = train_results[0]
            model.save(os.path.join(save_subdir, 'model_best_train_loss.ckpt'))
            write_runtime_state(os.path.join(save_subdir, 'last_checkpoint_state.json'),
                {
                    'current_epoch': epoch,
                    'n_processed': n_processed,
                    # Store the losses as plain numbers, matching the test-loss
                    # branch below, so the < comparisons still work after restore.
                    'train_loss_best': float(train_loss_best),
                    'test_loss_best': float(test_loss_best)
                })


        dtrain = time.time() - t
        ips = (hps.train_its * hvd.size() * hps.local_batch_train) / dtrain
        train_time += dtrain

        if hvd.rank() == 0:
            train_logger.log(epoch=epoch, n_processed=n_processed, n_images=n_images, train_time=int(
                train_time), **process_results(train_results))

        if epoch < 10 or (epoch < 50 and epoch % 10 == 0) or epoch % hps.epochs_full_valid == 0:
            test_results = []
            msg = ''

            t = time.time()
            # model.polyak_swap()

            if epoch % hps.epochs_full_valid == 0:
                # Full validation run
                for _ in range(hps.full_test_its):
                    test_results += [model.test()]
                test_results = np.mean(np.asarray(test_results), axis=0)

                if hvd.rank() == 0:
                    test_logger.log(epoch=epoch, n_processed=n_processed,
                                    n_images=n_images, **process_results(test_results))

                    # Save checkpoint
                    if test_results[0] < test_loss_best:
                        save_subdir = os.path.join(logdir, 'saved_models', 'best_test_loss')
                        os.makedirs(save_subdir, exist_ok=True)
                        test_loss_best = test_results[0]
                        model.save(os.path.join(save_subdir, "model_best_loss.ckpt"))
                        write_runtime_state(os.path.join(save_subdir, 'last_checkpoint_state.json'), 
                            {
                                'current_epoch': epoch, 
                                'n_processed': n_processed, 
                                'train_loss_best': float(train_loss_best), 
                                'test_loss_best': float(test_loss_best)
                            }
                            )
                        msg += ' *'
            
            dtest = time.time() - t

            # Sample
            t = time.time()
            if epoch == 1 or epoch == 10 or epoch % hps.epochs_full_sample == 0:
                visualise(epoch)
            dsample = time.time() - t

            if hvd.rank() == 0:
                dcurr = time.time() - tcurr
                tcurr = time.time()
                train_results = list2string(train_results)
                test_results = list2string(test_results)
                _print("{:<10} [{:<20}] [{:<20}] {:>10} {:>10} ({:.1f}  {:.1f}  {:.1f}  {:.1f}  {:.1f})".format(
                    epoch, train_results,test_results, n_processed,n_images, ips,  dtrain, dtest,  dsample, dcurr), msg)

            # model.polyak_swap()

    if hvd.rank() == 0:
        _print("Finished!")