示例#1
0
def train():
    """
    Naive multi-device data-parallel training.

    NOTE: the communicator exposes low-level interfaces.

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop
        and all-reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000   # CIFAR training-set size
    n_valid_samples = 10000   # CIFAR test-set size
    bs_valid = args.batch_size

    # Create Communicator and Context. Each MPI process drives the GPU
    # matching its local rank. (The class name "DataParalell" is the
    # library's own spelling.)
    extension_module = "cudnn"
    ctx = get_extension_context(extension_module, type_config=args.type_config)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx.device_id = str(device_id)
    nn.set_default_context(ctx)

    # Model: pick the network factory and the matching data iterator.
    # A fixed seed gives identical parameter initialization on every rank.
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=32,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graph. The loss is pre-divided by n_devices so the
    # summed all-reduced gradients equal the global-batch average.
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Create validation graph. Use bs_valid consistently for BOTH the image
    # and the label variable so their batch dimensions can never diverge.
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))
    input_image_valid = {"image": image_valid, "label": label_valid}

    # Solver with a linear learning-rate warmup over the first epochs:
    # lr ramps from base_lr to base_lr * n_devices across warmup_iter steps.
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitors; series are logged in units of global iterations
    # (i * n_devices).
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data iterators. Per-rank RNG seeds shuffle the training stream
    # differently on every device; validation uses bs_valid batches.
    rng = np.random.RandomState(device_id)
    _, tdata = data_iterator(args.batch_size, True, rng)
    vsource, vdata = data_iterator(bs_valid, False)

    # Training loop. Scalar variable `ve` carries the local validation
    # error so it can be all-reduced (averaged) across ranks.
    ve = nn.Variable()
    for i in range(int(args.max_iter / n_devices)):
        # Validation at every epoch boundary. Each rank evaluates its own
        # 1/n_devices shard of the (shuffled) validation set.
        if i % int(n_train_samples / args.batch_size / n_devices) == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image) != bs_valid:
                    # Smaller trailing batches are skipped.
                    continue
                input_image_valid["image"].d = image
                input_image_valid["label"].d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            comm.all_reduce(ve.data, division=True, inplace=True)

            # Monitoring and model snapshots happen on rank 0 only.
            if device_id == 0:
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)
                if i % int(args.model_save_interval / n_devices) == 0:
                    nn.save_parameters(
                        os.path.join(args.model_save_path,
                                     'params_%06d.h5' % i))

        # Forward/Zerograd
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward with gradient all-reduce (optionally overlapped via the
        # all-reduce callback).
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solver update
        solver.update()

        # Linear warmup of the learning rate.
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Local loss/error and elapsed time (rank 0 only).
        if device_id == 0:
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Final parameter snapshot (rank 0 only).
    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
示例#2
0
def train():
    """
    Single-device training loop for CIFAR ResNet-23 models.

    Workflow:
      * read CLI arguments and set the default compute context,
      * build the network factory and pick the matching data iterator,
      * construct separate training and validation graphs,
      * create an Adam solver and loss/error/time monitors,
      * iterate: periodic validation and checkpointing, forward,
        zero-grad, backward, solver update, per-iteration monitoring,
      * save the final parameter file.
    """
    # Command-line configuration.
    args = get_args()
    num_train = 50000  # CIFAR training-set size; defines the epoch length
    valid_bs = args.batch_size

    # Default context (device / dtype) for every graph built below.
    ctx = get_extension_context(args.context,
                                device_id=args.device_id,
                                type_config=args.type_config)
    nn.set_default_context(ctx)

    # Network factory and dataset selection.
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=10, nmaps=64, act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       ncls=100, nmaps=384, act=F.elu)
        data_iterator = data_iterator_cifar100

    # Training graph (test=False selects train-mode behaviour in the net).
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, False)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Validation graph (test=True).
    image_valid = nn.Variable((valid_bs, 3, 32, 32))
    pred_valid = prediction(image_valid, True)
    input_image_valid = {"image": image_valid}

    # Solver over all current parameters.
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())

    # Monitors for training/validation statistics.
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=1)

    # Data iterators (shuffle only the training stream).
    tdata = data_iterator(args.batch_size, True)
    vdata = data_iterator(args.batch_size, False)

    iters_per_epoch = int(num_train / args.batch_size)
    for i in range(args.max_iter):
        # Periodic validation at each epoch boundary.
        if i % iters_per_epoch == 0:
            verr = 0.
            for _ in range(args.val_iter):
                v_image, v_label = vdata.next()
                input_image_valid["image"].d = v_image
                pred_valid.forward()
                verr += categorical_error(pred_valid.d, v_label)
            monitor_verr.add(i, verr / args.val_iter)
        # Periodic parameter snapshot.
        if int(i % args.model_save_interval) == 0:
            nn.save_parameters(
                os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # One optimization step: forward, clear grads, backward, update.
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()
        solver.update()

        # Per-iteration training statistics.
        err = categorical_error(pred_train.d, input_image_train["label"].d)
        monitor_loss.add(i, loss_train.d.copy())
        monitor_err.add(i, err)
        monitor_time.add(i)

    # Final snapshot after the last iteration.
    nn.save_parameters(
        os.path.join(args.model_save_path, 'params_%06d.h5' % (args.max_iter)))
示例#3
0
def train():
    """
    Naive Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    * Parse command line arguments.
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Compute error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop
        and all-reduce.
      * Compute training error
    """
    # Parse args
    args = get_args()
    n_train_samples = 50000  # CIFAR training-set size; defines epoch length
    bs_valid = args.batch_size
    # Fixed seed -> identical parameter initialization on every rank.
    rng = np.random.RandomState(313)
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu)
        data_iterator = data_iterator_cifar100

    # Communicator and Context. Each MPI process is re-bound to the GPU
    # matching its local rank.
    # NOTE(review): "DataParalell" is the library's own (misspelled) class
    # name — do not "correct" it. `extension_context` is the older context
    # API (cf. `get_extension_context` used elsewhere in this file).
    extension_module = "cuda.cudnn"
    ctx = extension_context(extension_module)
    comm = C.MultiProcessDataParalellCommunicator(ctx)
    comm.init()
    n_devices = comm.size
    mpi_rank = comm.rank
    mpi_local_rank = comm.local_rank
    device_id = mpi_local_rank
    ctx = extension_context(extension_module, device_id=device_id)
    nn.set_default_context(ctx)

    # Create training graphs
    test = False
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test)
    loss_train = loss_function(pred_train, label_train)
    input_image_train = {"image": image_train, "label": label_train}

    # Register parameters with the communicator. This must happen AFTER the
    # training graph is built so that all parameters already exist.
    comm.add_context_and_parameters((ctx, nn.get_parameters()))

    # Create validation graph
    test = True
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    pred_valid = prediction(image_valid, test)
    input_image_valid = {"image": image_valid}

    # Solver with linear learning-rate warmup: lr ramps from base_lr to
    # base_lr * n_devices across warmup_iter steps.
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = int(
        1. * n_train_samples / args.batch_size / n_devices) * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Create monitor; series are logged in global-iteration units
    # (i * n_devices).
    from nnabla.monitor import Monitor, MonitorSeries, MonitorTimeElapsed
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Test error", monitor, interval=10)

    # Data Iterator. Per-rank RNG seed shuffles the training stream
    # differently on each device.
    rng = np.random.RandomState(device_id)
    tdata = data_iterator(args.batch_size, True, rng)
    vdata = data_iterator(args.batch_size, False)

    # Training-loop (each rank runs max_iter / n_devices local iterations)
    for i in range(int(args.max_iter / n_devices)):
        # Validation and checkpointing are done by rank 0 only.
        if device_id == 0:
            if i % int(n_train_samples / args.batch_size / n_devices) == 0:
                ve = 0.
                for j in range(args.val_iter):
                    image, label = vdata.next()
                    input_image_valid["image"].d = image
                    pred_valid.forward()
                    ve += categorical_error(pred_valid.d, label)
                ve /= args.val_iter
                monitor_verr.add(i * n_devices, ve)
            if i % int(args.model_save_interval / n_devices) == 0:
                nn.save_parameters(
                    os.path.join(args.model_save_path, 'params_%06d.h5' % i))

        # Forward/Zerograd/Backward
        image, label = tdata.next()
        input_image_train["image"].d = image
        input_image_train["label"].d = label
        loss_train.forward()
        solver.zero_grad()
        loss_train.backward()

        # Allreduce gradients across devices before the update.
        # NOTE(review): division=False keeps the SUM of per-device gradients
        # and the loss is not pre-divided by n_devices here, so the effective
        # gradient scales with the device count — confirm this is intended.
        comm.allreduce(division=False, inplace=False)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Local training stats and elapsed time (rank 0 only).
        if device_id == 0:
            e = categorical_error(pred_train.d, input_image_train["label"].d)
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, e)
            monitor_time.add(i * n_devices)

    # Final parameter snapshot (rank 0 only).
    if device_id == 0:
        nn.save_parameters(
            os.path.join(args.model_save_path,
                         'params_%06d.h5' % (args.max_iter / n_devices)))
示例#4
0
# Adam over all model parameters with a fixed learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Since CUDA raises a segmentation fault at random epochs for unknown
# reasons, training is completed by re-running this script repeatedly:
# the latest checkpoint and the loss histories are reloaded at each start.
th.loadModel('p1_latest.pth', model, optimizer)
model.train()
# Per-epoch loss histories are persisted to disk so the plot data
# survives a crash/restart.
train_MSE = np.load("hw3_data/p1_plot_npy/train_MSE.npy").tolist()
train_KLD = np.load("hw3_data/p1_plot_npy/train_KLD.npy").tolist()

# Uncomment to start from scratch instead of resuming:
#train_MSE = []
#train_KLD = []
# One history entry per epoch; stop once 200 epochs have been recorded.
while len(train_MSE) < 200:
    print('Epoch:',len(train_MSE))
    MSE_loss, KLD_loss = 0.0, 0.0
    for batch_idx, data in enumerate(trainset_loader):
        data = data.to(device)
        data = Variable(data)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        # loss_function returns (total loss, KL divergence, reconstruction
        # term); only the total is backpropagated.
        loss, KLD, MSE = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()
        MSE_loss += float(MSE.data)
        KLD_loss += float(KLD.data)
    # NOTE(review): 12288 is presumably 3*64*64, the per-image element
    # count, giving a per-pixel reconstruction loss — confirm against the
    # model's input size.
    print("training Recon Loss:",MSE_loss/(12288*len(trainset)))
    print("training KLD_loss:", KLD_loss/len(trainset))
    print('')
    train_MSE.append(MSE_loss/(12288*len(trainset)))
    train_KLD.append(KLD_loss/len(trainset))
    # Checkpoint every epoch so a crash loses at most one epoch of work.
    th.saveModel('p1_latest.pth',model,optimizer)
    np.save("hw3_data/p1_plot_npy/train_MSE.npy", train_MSE)
    np.save("hw3_data/p1_plot_npy/train_KLD.npy", train_KLD)
示例#5
0
        # NOTE(review): this is the interior of a per-batch training step;
        # `loss`, `total_loss`, `batch`, `target`, `img_tensor` and the
        # models are bound outside the visible fragment.
        # Re-initialize the decoder hidden state for each batch because the
        # captions are not related from image to image.
        hidden = decoder.reset_state(batch_size=target.shape[0])

        # First decoder input: the <start> token replicated across the batch.
        dec_input = tf.expand_dims([tokenizer.word_index['<start>']] *
                                   BATCH_SIZE, 1)

        with tf.GradientTape() as tape:
            features = encoder(img_tensor)

            # Decode one time step at a time, accumulating the loss over
            # positions 1..T-1 (position 0 is the <start> token).
            for i in range(1, target.shape[1]):
                # passing the features through the decoder
                predictions, hidden, _ = decoder(dec_input, features, hidden)

                loss += loss_function(target[:, i], predictions)

                # Teacher forcing: feed the ground-truth token, not the
                # prediction, as the next decoder input.
                dec_input = tf.expand_dims(target[:, i], 1)

        # Track the per-token average loss for reporting.
        total_loss += (loss / int(target.shape[1]))

        # Joint update of encoder and decoder from one tape pass.
        variables = encoder.variables + decoder.variables

        gradients = tape.gradient(loss, variables)

        optimizer.apply_gradients(zip(gradients, variables),
                                  tf.train.get_or_create_global_step())

        # Periodic checkpointing every 100 batches.
        if batch % 100 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
示例#6
0
def train(args):
    """
    Multi-Device Training

    NOTE: the communicator exposes low-level interfaces

    Args:
        args: parsed command-line namespace; fields used here include
            batch_size, context, device_id, net, sync_bn, learning_rate,
            warmup_epoch, use_latest_checkpoint, model_save_path,
            model_save_interval, monitor_path, epochs and
            with_all_reduce_callback.

    Steps:
    * Instantiate a communicator and set parameter variables.
    * Specify contexts for computation.
    * Initialize DataIterator.
    * Construct a computation graph for training and one for validation.
    * Initialize solver and set parameter variables to that.
    * Load checkpoint to resume previous training.
    * Create monitor instances for saving and displaying training stats.
    * Training loop
      * Computate error rate for validation data (periodically)
      * Get a next minibatch.
      * Execute forwardprop
      * Set parameter gradients zero
      * Execute backprop.
      * AllReduce for gradients
      * Solver updates parameters by using gradients computed by backprop and all reduce.
      * Compute training error
    """
    # Create Communicator and Context. Falls back to single-device mode
    # when no communicator can be created (ignore_error=True).
    comm = create_communicator(ignore_error=True)
    if comm:
        n_devices = comm.size
        mpi_rank = comm.rank
        device_id = comm.local_rank
    else:
        n_devices = 1
        mpi_rank = 0
        device_id = args.device_id

    # Select the compute backend; each process drives its local-rank GPU.
    if args.context == 'cpu':
        import nnabla_ext.cpu
        context = nnabla_ext.cpu.context()
    else:
        import nnabla_ext.cudnn
        context = nnabla_ext.cudnn.context(device_id=device_id)
    nn.set_default_context(context)

    n_train_samples = 50000  # CIFAR training-set size
    n_valid_samples = 10000  # CIFAR test-set size
    bs_valid = args.batch_size
    # Local iterations per epoch (the global batch is batch_size * n_devices).
    iter_per_epoch = int(n_train_samples / args.batch_size / n_devices)

    # Model. Fixed seed -> identical parameter init on every rank;
    # sync-BN gets the communicator only when requested.
    rng = np.random.RandomState(313)
    comm_syncbn = comm if args.sync_bn else None
    if args.net == "cifar10_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=10,
                                       nmaps=64,
                                       act=F.relu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar10
    if args.net == "cifar100_resnet23":
        prediction = functools.partial(resnet23_prediction,
                                       rng=rng,
                                       ncls=100,
                                       nmaps=384,
                                       act=F.elu,
                                       comm=comm_syncbn)
        data_iterator = data_iterator_cifar100

    # Create training graphs. The loss is pre-divided by n_devices so the
    # summed all-reduced gradients equal the global-batch average.
    image_train = nn.Variable((args.batch_size, 3, 32, 32))
    label_train = nn.Variable((args.batch_size, 1))
    pred_train = prediction(image_train, test=False)
    pred_train.persistent = True
    loss_train = (loss_function(pred_train, label_train) /
                  n_devices).apply(persistent=True)
    error_train = F.mean(F.top_n_error(pred_train, label_train,
                                       axis=1)).apply(persistent=True)
    loss_error_train = F.sink(loss_train, error_train)

    # Create validation graphs
    image_valid = nn.Variable((bs_valid, 3, 32, 32))
    label_valid = nn.Variable((bs_valid, 1))
    pred_valid = prediction(image_valid, test=True)
    error_valid = F.mean(F.top_n_error(pred_valid, label_valid, axis=1))

    # Solver with linear learning-rate warmup: lr ramps from base_lr to
    # base_lr * n_devices across warmup_iter steps.
    solver = S.Adam()
    solver.set_parameters(nn.get_parameters())
    base_lr = args.learning_rate
    warmup_iter = iter_per_epoch * args.warmup_epoch
    warmup_slope = base_lr * (n_devices - 1) / warmup_iter
    solver.set_learning_rate(base_lr)

    # Resume: pick the highest-numbered checkpoint_*.json, restore weights
    # and solver state, and continue from its recorded iteration.
    start_point = 0
    if args.use_latest_checkpoint:
        files = glob.glob(f'{args.model_save_path}/checkpoint_*.json')
        if len(files) != 0:
            index = max([
                int(n) for n in
                [re.sub(r'.*checkpoint_(\d+).json', '\\1', f) for f in files]
            ])
            # load weights and solver state info from specified checkpoint file.
            start_point = load_checkpoint(
                f'{args.model_save_path}/checkpoint_{index}.json', solver)
        print(f'checkpoint is loaded. start iteration from {start_point}')

    # Create monitor; series are logged in global-iteration units
    # (i * n_devices).
    monitor = Monitor(args.monitor_path)
    monitor_loss = MonitorSeries("Training loss", monitor, interval=10)
    monitor_err = MonitorSeries("Training error", monitor, interval=10)
    monitor_time = MonitorTimeElapsed("Training time", monitor, interval=10)
    monitor_verr = MonitorSeries("Validation error", monitor, interval=1)
    monitor_vtime = MonitorTimeElapsed("Validation time", monitor, interval=1)

    # Data Iterator

    # If the data does not exist, it will try to download it from the server
    # and prepare it. When executing multiple processes on the same host, it is
    # necessary to execute initial data preparation by the representative
    # process (rank is 0) on the host.

    # Download dataset by rank-0 process
    if single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Wait for data to be prepared without watchdog
    if comm:
        comm.barrier()

    # Prepare dataset for remaining process
    if not single_or_rankzero():
        rng = np.random.RandomState(mpi_rank)
        _, tdata = data_iterator(args.batch_size, True, rng)
        vsource, vdata = data_iterator(bs_valid, False)

    # Training loop. Scalar variable `ve` carries the local validation
    # error so it can be all-reduced (averaged) across ranks.
    ve = nn.Variable()
    for i in range(start_point // n_devices, args.epochs * iter_per_epoch):
        # Validation at every epoch boundary. Each rank evaluates its own
        # 1/n_devices shard of the (shuffled) validation set.
        if i % iter_per_epoch == 0:
            ve_local = 0.
            k = 0
            idx = np.random.permutation(n_valid_samples)
            val_images = vsource.images[idx]
            val_labels = vsource.labels[idx]
            for j in range(int(n_valid_samples / n_devices * mpi_rank),
                           int(n_valid_samples / n_devices * (mpi_rank + 1)),
                           bs_valid):
                image = val_images[j:j + bs_valid]
                label = val_labels[j:j + bs_valid]
                if len(image
                       ) != bs_valid:  # note that smaller batch is ignored
                    continue
                image_valid.d = image
                label_valid.d = label
                error_valid.forward(clear_buffer=True)
                ve_local += error_valid.d.copy()
                k += 1
            ve_local /= k
            ve.d = ve_local
            if comm:
                comm.all_reduce(ve.data, division=True, inplace=True)

            # Monitoring error and elapsed time
            if single_or_rankzero():
                monitor_verr.add(i * n_devices, ve.d.copy())
                monitor_vtime.add(i * n_devices)

        # Save model (rank 0 only); checkpoints also record solver state
        # so training can be resumed exactly.
        if single_or_rankzero():
            if i % (args.model_save_interval // n_devices) == 0:
                iter = i * n_devices
                nn.save_parameters(
                    os.path.join(args.model_save_path,
                                 'params_%06d.h5' % iter))
                if args.use_latest_checkpoint:
                    save_checkpoint(args.model_save_path, iter, solver)

        # Forward/Zerograd
        image, label = tdata.next()
        image_train.d = image
        label_train.d = label
        loss_error_train.forward(clear_no_need_grad=True)
        solver.zero_grad()

        # Backward with gradient all-reduce (optionally overlapped via
        # the all-reduce callback).
        backward_and_all_reduce(
            loss_error_train,
            comm,
            with_all_reduce_callback=args.with_all_reduce_callback)

        # Solvers update
        solver.update()

        # Linear Warmup
        if i <= warmup_iter:
            lr = base_lr + warmup_slope * i
            solver.set_learning_rate(lr)

        # Monitoring loss, error and elapsed time
        if single_or_rankzero():
            monitor_loss.add(i * n_devices, loss_train.d.copy())
            monitor_err.add(i * n_devices, error_train.d.copy())
            monitor_time.add(i * n_devices)

    # Save the final parameters plus a deployable .nnp bundle describing
    # the validation network as a runtime executor (rank 0 only).
    if single_or_rankzero():
        runtime_contents = {
            'networks': [{
                'name': 'Validation',
                'batch_size': args.batch_size,
                'outputs': {
                    'y': pred_valid
                },
                'names': {
                    'x': image_valid
                }
            }],
            'executors': [{
                'name': 'Runtime',
                'network': 'Validation',
                'data': ['x'],
                'output': ['y']
            }]
        }
        iter = args.epochs * iter_per_epoch
        nn.save_parameters(
            os.path.join(args.model_save_path, 'params_%06d.h5' % iter))
        nnabla.utils.save.save(
            os.path.join(args.model_save_path, f'{args.net}_result.nnp'),
            runtime_contents)
    # Keep all ranks in sync before returning.
    if comm:
        comm.barrier()
def train(all_models, training_models, solver, training_params, log_every,
          **kwargs):
    """
    Joint training loop for a reconstruction (VAE) model, a contrastive
    model and an actor on trajectory data.

    Args:
        all_models: tuple ``(model, c_model, actor)``.
        training_models: collection of models to optimize; a model's loss
            is computed and backpropagated only if it is in this collection.
        solver: optimizer whose ``step()`` applies accumulated gradients;
            may be ``None`` to skip updates.
        training_params: parameters handed to ``reset_grad`` after each step.
        log_every: logging/saving/image-dump period, in iterations.
        **kwargs: experiment configuration. Keys read here: "k",
            "n_epochs", "batch_size", "N", "c_type", "vae_w", "vae_b",
            "savepath", "conditional", "data_dir", "test_dir",
            "use_o_neg", "pretrain".
    """
    model, c_model, actor = all_models
    k_steps = kwargs["k"]
    num_epochs = kwargs["n_epochs"]
    batch_size = kwargs["batch_size"]
    N = kwargs["N"]
    c_type = kwargs["c_type"]
    vae_weight = kwargs["vae_w"]
    beta = kwargs["vae_b"]

    # Configure experiment path
    savepath = kwargs['savepath']

    conditional = kwargs["conditional"]

    # Tensorboard-style logger writing under <savepath>/var_log.
    configure('%s/var_log' % savepath, flush_secs=5)

    ### Load data ### -- assuming appropriate npy format
    data_file = kwargs["data_dir"]
    data = np.load(data_file)
    n_trajs = len(data)
    # Transitions usable as (o_t, o_{t+k}) pairs: the last k_steps of each
    # trajectory have no k-step successor.
    data_size = sum([len(data[i]) - k_steps for i in range(n_trajs)])
    print('Number of trajectories: %d' % n_trajs)  # 315
    print('Number of transitions: %d' % data_size)  # 378315
    
    test_file = kwargs["test_dir"]
    test_data = np.load(test_file)
    test_context = get_torch_images_from_numpy(test_data,
                                               conditional,
                                               one_image=True)

    ### Train models ###
    # Zero-initialized so losses of models not being trained still log as 0.
    c_loss = vae_loss = a_loss = torch.Tensor([0]).cuda()
    for epoch in range(num_epochs):
        n_batch = int(data_size / batch_size)
        print('********** Epoch %i ************' % epoch)
        for it in range(n_batch):
            # Sample (trajectory, timestep) pairs, the current observation o
            # with context c, a random k-step-ahead successor o_next, and
            # optionally negative examples from other contexts.
            idx, t = get_idx_t(batch_size, k_steps, n_trajs, data)
            o, c = get_torch_images_from_numpy(data[idx, t], conditional)
            ks = np.random.choice(k_steps, batch_size)
            o_next, _ = get_torch_images_from_numpy(data[idx, t + ks],
                                                    conditional)
            o_neg = get_negative_examples(
                data, idx, batch_size, N,
                conditional) if kwargs["use_o_neg"] else None
            o_pred, mu, logvar, cond_info = model(o, c)
            o_next_pred, _, _, _ = model(o_next, c)

            # VAE loss (reconstruction + beta-weighted KL), scaled by
            # vae_weight before backprop.
            if model in training_models:
                vae_loss = loss_function(o_pred,
                                         o,
                                         mu,
                                         logvar,
                                         cond_info.get("means_cond", None),
                                         cond_info.get("log_var_cond", None),
                                         beta=beta) * vae_weight
                vae_loss.backward()

            # Contrastive loss; deferred until after the VAE pretraining
            # epochs have elapsed.
            if c_model in training_models and epoch >= kwargs["pretrain"]:
                c_loss = get_c_loss(model, c_model, c_type, o_pred,
                                    o_next_pred, c, N, o_neg)
                c_loss.backward()

            # Actor loss on the next-step actions; also gated by pretrain.
            if actor in training_models and epoch >= kwargs["pretrain"]:
                a = get_torch_actions(data[idx, t + 1])
                a_loss = actor.loss(a, o, o_next, c)
                a_loss.backward()

            ### Update models ###
            # Gradients from all backward() calls above are applied in one
            # step, then cleared for the next iteration.
            if solver is not None:
                solver.step()
            reset_grad(training_params)

            if it % log_every == 0:
                ### Log info ###
                log_info(c_loss, vae_loss, a_loss, model, conditional,
                         cond_info, it, n_batch, epoch)

                ### Save params ###
                # Rotating snapshots: epoch % 5 + 1 keeps the last 5 epochs.
                if not os.path.exists('%s/var' % savepath):
                    os.makedirs('%s/var' % savepath)
                torch.save(model.state_dict(),
                           '%s/var/vae-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(c_model.state_dict(),
                           '%s/var/cpc-%d-last-5' % (savepath, epoch % 5 + 1))
                torch.save(
                    actor.state_dict(),
                    '%s/var/actor-%d-last-5' % (savepath, epoch % 5 + 1))

                ### Log images ###
                # Reconstructions and cross-context samples, computed
                # without tracking gradients.
                with torch.no_grad():
                    n_contexts = 7
                    n_samples_per_c = 8
                    o_distinct_c = get_negative_examples(
                        data, idx[:n_contexts], n_contexts, n_samples_per_c,
                        conditional)
                    log_images(
                        o[:n_contexts], o_pred[:n_contexts],
                        o_distinct_c.reshape(n_samples_per_c, n_contexts,
                                             *o_distinct_c.size()[1:]),
                        c[:n_contexts], test_context, model, c_model,
                        n_contexts, n_samples_per_c, savepath, epoch)