Example #1
def train_function(parallel_index):
    global global_t
    global stop_requested

    training_thread = training_threads[parallel_index]

    checkpoint_count = 0

    while True:
        if stop_requested:
            break
        if global_t > args.max_time_step:
            break

        diff_global_t = training_thread.process(sess, global_t, summary_writer,
                                                record_score_fn)
        # score_summary_op, score_input)
        global_t += diff_global_t

        # global checkpoint saving
        if parallel_index == 0 and checkpoints.should_save_checkpoint(
                global_t, 1000000, checkpoint_count):
            checkpoint_count += 1
            print("Saving checkpoint %d at t=%d" % (checkpoint_count, global_t))
            checkpoints.save_checkpoint(checkpoint_name=checkpoint_name)
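
The checkpoints.should_save_checkpoint helper is not shown in this snippet. A minimal sketch of the interval logic its call site implies (the name and signature come from the call above; the body is an assumption):

def should_save_checkpoint(global_t, interval, checkpoint_count):
    # Hypothetical helper: return True once global_t has crossed the next
    # interval boundary that has not been checkpointed yet.
    return global_t >= (checkpoint_count + 1) * interval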
Example #2
def train_function(parallel_index):
    global global_t
    global stop_requested

    training_thread = training_threads[parallel_index]

    checkpoint_count = 0

    while True:
        if stop_requested:
            break
        if global_t > args.max_time_step:
            break

        diff_global_t = training_thread.process(sess, global_t, summary_writer,
                                                record_score_fn)
        # score_summary_op, score_input)
        global_t += diff_global_t

        # global checkpoint saving
        if parallel_index == 0 and checkpoints.should_save_checkpoint(global_t, 1000000, checkpoint_count):
            checkpoint_count += 1
            print("Saving checkpoint %d at t=%d" % (checkpoint_count, global_t))
            checkpoints.save_checkpoint(checkpoint_name=checkpoint_name)
Example #3
    def save_checkpoint(self, name, path):

        run_update_var(self._session, self._training_timestep_var, self._training_timestep.value)
        print("saving checkpoint for training timestep %d" % self._training_timestep_var.eval(session=self._session))
        checkpoints.save_checkpoint(self._session,
                                    saver=self._saver,
                                    #global_step=self._training_timestep.value,
                                    global_step=self._training_timestep_var,
                                    checkpoint_name=name,
                                    path=path)
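
A plausible sketch of the checkpoints.save_checkpoint wrapper used above, assuming it simply delegates to a TF1 tf.train.Saver (the signature is inferred from the call site, not taken from any library):

import os

def save_checkpoint(session, saver, global_step, checkpoint_name, path):
    # Hypothetical wrapper: delegate to tf.train.Saver.save, producing
    # <path>/<checkpoint_name>-<global_step> checkpoint files.
    if not os.path.exists(path):
        os.makedirs(path)
    return saver.save(session, os.path.join(path, checkpoint_name),
                      global_step=global_step)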
Example #4
def train(start_epoch, end_epoch, model, criterion, optimizer, device, X_train, X_valid, y_train, y_valid):
    for epoch in range(start_epoch, end_epoch):
        optimizer.zero_grad()  # Clears existing gradients from previous epoch
        X_train = X_train.to(device)  # .to() is not in-place, so keep the returned tensor
        y_train = y_train.to(device)  # targets must live on the same device as the outputs
        output = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()  # Does backpropagation and calculates gradients
        optimizer.step()  # Updates the weights accordingly
        accuracy = test(X_valid, y_valid, model)
        print_loss_accuracy(epoch, loss.item(), accuracy, every=10)
        save_checkpoint({"epoch": epoch + 1, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}, every=25)
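
The save_checkpoint called above is not defined in this snippet. A minimal sketch, assuming it is a thin wrapper around torch.save that only writes on every `every`-th epoch (the file name is an assumption):

import torch

def save_checkpoint(state, every=25, path="checkpoint.pt"):
    # Hypothetical wrapper: persist the training state only every `every` epochs.
    if state["epoch"] % every == 0:
        torch.save(state, path)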
Example #5
    def save_checkpoint(self, name, path):

        run_update_var(self._session, self._training_timestep_var,
                       self._training_timestep.value)
        print("saving checkpoint for training timestep %d" %
              self._training_timestep_var.eval(session=self._session))
        checkpoints.save_checkpoint(
            self._session,
            saver=self._saver,
            #global_step=self._training_timestep.value,
            global_step=self._training_timestep_var,
            checkpoint_name=name,
            path=path)
Example #6
def main():
    parser = Parser()
    config = parser.config

    for param, value in config.__dict__.items():
        print(param + '.' * (50 - len(param) - len(str(value))) + str(value))
    print()

    # Load previous checkpoint if it exists
    checkpoint = load_latest(config)

    # Create model
    model = load_model(config, checkpoint)

    # print number of parameters in the model
    n_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    print('Total number of parameters: \33[91m{}\033[0m'.format(n_params))

    # Load train and test data
    train_loader, valid_loader, test_loader = Loader(config)

    n_batches = int(len(train_loader.dataset.train_data) / config.batch_size)

    # save the configuration
    with open(os.path.join(config.save, 'log.txt'), 'w') as file:
        json.dump('json_stats: ' + str(config.__dict__), file)

    # Instantiate the criterion, optimizer and learning rate scheduler
    criterion = torch.nn.CrossEntropyLoss(size_average=True)

    optimizer = torch.optim.SGD(model.parameters(),
                                lr=config.LR,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay,
                                nesterov=config.nesterov)

    start_time = 0
    if checkpoint is not None:
        start_epoch = checkpoint['time'] + 1
        optimizer.load_state_dict(checkpoint['optimizer'])

    if config.lr_shape == 'multistep':
        scheduler = MultiStepLR(optimizer, milestones=[81, 122], gamma=0.1)
    elif config.lr_shape == 'cosine':
        if checkpoint is not None:
            scheduler = checkpoint['scheduler']
        else:
            scheduler = CosineAnnealingRestartsLR(optimizer,
                                                  1,
                                                  config.T_e,
                                                  T_mul=config.T_mul)

    # The trainer handles the training loop and evaluation on validation set
    trainer = Trainer(model, criterion, config, optimizer, scheduler)

    epoch = 1

    while True:
        # Train for a single epoch
        train_top1, train_loss, stop_training = trainer.train(
            epoch, train_loader)

        # Run model on the validation and test set
        valid_top1 = trainer.evaluate(epoch, valid_loader, 'valid')
        test_top1 = trainer.evaluate(epoch, test_loader, 'test')

        current_time = time.time()

        results = {
            'epoch': epoch,
            'time': current_time,
            'train_top1': train_top1,
            'valid_top1': valid_top1,
            'test_top1': test_top1,
            'train_loss': float(train_loss.data),
        }

        with open(os.path.join(config.save, 'results.txt'), 'w') as file:
            json.dump(str(results), file)
            file.write('\n')

        print(
            '==> Finished epoch %d (budget %.3f): %7.3f (train) %7.3f (validation) %7.3f (test)'
            % (epoch, config.budget, train_top1, valid_top1, test_top1))

        if stop_training:
            break

        epoch += 1

    if start_time >= config.budget:
        trainer.evaluate(epoch, test_loader, 'test')
    else:
        save_checkpoint(int(config.budget), trainer.model, trainer.optimizer,
                        trainer.scheduler, config)
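
The save_checkpoint used here (and again in Example #10) is not included. A hedged sketch of what it likely persists, consistent with the keys the code above reads back from the checkpoint ('time', 'optimizer', 'scheduler'); the file name and the 'model' key are assumptions:

import os
import torch

def save_checkpoint(budget, model, optimizer, scheduler, config):
    # Hypothetical sketch: bundle everything needed to resume training and
    # name the file after the consumed time budget in seconds.
    state = {
        'time': budget,
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler,
    }
    torch.save(state, os.path.join(config.save, 'checkpoint_{}.pth'.format(budget)))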
Example #7
def train(
    name,
    model,
    training_data,
    optimizer,
    device,
    epochs,
    validation_data=None,
    tb=None,
    log_interval=100,
    interpolate_interval=1,
    interpolate_data=None,
    start_epoch=0,
    start_batch=0,
    total_loss=0,
    n_char_total=0,
    n_char_correct=0,
    run_batches=0,
    best_valid_accu=0.0,
    best_valid_loss=float("Inf"),
    best_interpolate_accu=0.0,
    best_interpolate_loss=float("Inf"),
    run_max_batches=None,
    extrapolate_data=None,
    checkpoint=True,
    lr=None,
    warmup_lr=None,
    warmup_interval=None,
    smoothing=False,
):
    print("~~~ Beginning Training ~~~~")
    print(
        f"Start epoch: {start_epoch}, Start batch: {start_batch}, Max batch: {run_max_batches}"
    )

    for epoch_i in range(start_epoch, epochs):

        start = time.time()

        print(
            f"[ Epoch: {epoch_i} / {epochs}, Run Batch: {run_batches} / {run_max_batches}]"
        )

        train_loss, train_accu, new_batch_count, done = train_epoch(
            model=model,
            name=name,
            training_data=training_data,
            optimizer=optimizer,
            device=device,
            epoch=epoch_i,
            tb=tb,
            log_interval=log_interval,
            max_batches=run_max_batches,
            run_batch_count=run_batches,
            start_batch=start_batch,
            total_loss=total_loss,
            n_char_total=n_char_total,
            n_char_correct=n_char_correct,
            lr=lr,
            warmup_lr=warmup_lr,
            warmup_interval=warmup_interval,
            smoothing=smoothing,
        )

        run_batches = new_batch_count

        print(
            "[Training]  loss: {train_loss}, ppl: {ppl: 8.6f}, accuracy: {accu:3.3f} %, "
            "elapse: {elapse:3.3f}ms".format(
                train_loss=train_loss,
                ppl=math.exp(min(train_loss, 100)),
                accu=100 * train_accu,
                elapse=(time.time() - start) * 1000,
            ))

        if not utils.is_preempted():
            inference_datasets = {}
            if interpolate_data:
                inference_datasets["interpolate"] = interpolate_data
            if extrapolate_data:
                inference_datasets["extrapolate"] = extrapolate_data

            for group, dataset in inference_datasets.items():
                start = time.time()
                inference_loss, inference_acc = inference_epoch(
                    model,
                    dataset,
                    device,
                    epoch_i,
                    group,
                    tb,
                    log_interval,
                )
                print(
                    "[{group}]  loss: {inference_loss},  ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, "
                    "elapse: {elapse:3.3f}ms".format(
                        group=group,
                        inference_loss=inference_loss,
                        ppl=math.exp(min(inference_loss, 100)),
                        accu=100 * inference_acc,
                        elapse=(time.time() - start) * 1000,
                    ))

        if done or checkpoint:
            print("Building checkpoint..")
            start = time.time()
            state = build_checkpoint(
                name=name,
                model=model,
                optimizer=optimizer,
                acc=train_accu,
                loss=train_loss,
                epoch=epoch_i if not done and checkpoint else epoch_i + 1,
                run_batches=run_batches,
                # is_preempted=utils.is_preempted(),
                start_batch=0,
                lr=lr,
            )

            if utils.is_cloud():
                print("Saving to google cloud..")
                checkpoint_name = "checkpoint"
                if done:
                    checkpoint_name = (
                        f"{checkpoint_name}_b{run_batches}_e{epoch_i}_complete"
                    )
                elif checkpoint:
                    checkpoint_name = f"{checkpoint_name}_b{run_batches}_e{epoch_i}"

                save_checkpoint(
                    state=state,
                    name=checkpoint_name,
                    path="./checkpoints",
                )
            else:
                rotating_save_checkpoint(
                    state,
                    prefix=f"{name}_{run_batches}_training",
                    path="./checkpoints",
                    nb=5,
                )
            print(f"Save checkpoint time: {(time.time() - start) * 1000}ms")
            # if utils.is_preempted():
            #     print("Completed preemption handling. Cleanly exiting.")
            #     sys.exit(0)

        if done:
            print(
                f"Reached max batch. Breaking out of training at the end of epoch {epoch_i}"
            )
            break

        start_batch = 0
        training_data.dataset.endEpoch()
        training_data.dataset.shuffleData()

    print("~~~~~~ Completed training ~~~~~~")

    if utils.is_cloud():
        print("Shutting down instance")
        os.system("sudo shutdown -h now")
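
The rotating_save_checkpoint helper above is not shown. A minimal sketch under the assumption that it keeps only the nb most recent checkpoint files with the given prefix (the naming scheme is an assumption):

import glob
import os
import time
import torch

def rotating_save_checkpoint(state, prefix, path="./checkpoints", nb=5):
    # Hypothetical sketch: write a timestamped checkpoint, then delete the
    # oldest files so that at most `nb` checkpoints with this prefix remain.
    os.makedirs(path, exist_ok=True)
    filename = os.path.join(path, "{}_{}.pt".format(prefix, int(time.time())))
    torch.save(state, filename)
    existing = sorted(glob.glob(os.path.join(path, "{}_*.pt".format(prefix))))
    for old in existing[:-nb]:
        os.remove(old)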
Example #8
def train_epoch(
    model,
    name,
    training_data,
    optimizer,
    device,
    epoch,
    tb=None,
    log_interval=100,
    max_batches=None,
    run_batch_count=0,
    start_batch=0,
    total_loss=0,
    n_char_total=0,
    n_char_correct=0,
    lr=None,
    warmup_lr=None,
    warmup_interval=None,
    smoothing=False,
):

    training_iter = iter(training_data)

    if start_batch > 0:
        last_question = np_encode_string(
            training_data.dataset.__getitem__(-1)["q"])
        print(f"Final question before checkpoint was {last_question}")

    model.train()
    # interrupted_batch = None
    done = False

    loss_per_char = 0
    accuracy = 0

    for batch_idx, batch in enumerate(training_iter, start=start_batch):
        if utils.is_preempted():
            print("Exiting...")
            sys.exit(0)

        if warmup_interval is not None and batch_idx == warmup_interval:
            print(
                f"End of warmup. Swapping learning rates from {warmup_lr} to {lr}"
            )
            for param_group in optimizer.param_groups:
                warmup_lr = lr
                param_group["lr"] = lr

        batch_qs, batch_qs_pos, batch_as, batch_as_pos = map(
            lambda x: x.to(device), batch)

        gold_as = batch_as[:, 1:]

        optimizer.zero_grad()
        pred_as = model(batch_qs, batch_qs_pos, batch_as, batch_as_pos)

        loss, n_correct = compute_performance(pred_as,
                                              gold_as,
                                              smoothing=smoothing)

        loss.backward()

        # Clip gradients, paper uses 0.1
        clip_grad_value_(model.parameters(), 0.1)

        # update parameters
        optimizer.step()

        # note keeping
        total_loss += loss.item()

        non_pad_mask = gold_as.ne(Constants.PAD)
        n_char = non_pad_mask.sum().item()
        n_char_total += n_char
        n_char = n_char if n_char > 1 else 1

        batch_loss = loss / n_char
        loss_per_char = total_loss / n_char_total

        n_char_correct += n_correct
        batch_acc = n_correct / n_char
        accuracy = n_char_correct / n_char_total
        print(
            f"Batch: {batch_idx}. Acc: {accuracy:.6f}. Loss: {loss_per_char:.6f}. Batch_acc: {batch_acc:.6f}. Batch_loss: {batch_loss:.6f} "
        )

        # TODO: automatically trim the TB logs that go beyond the preempted checkpoint
        if tb is not None and batch_idx % log_interval == 0:
            tb.add_scalars(
                {
                    "loss_per_char": loss_per_char,
                    "accuracy": accuracy,
                    "batch_loss": batch_loss,
                    "batch_acc": batch_acc,
                },
                group="train",
                sub_group="batch",
                global_step=run_batch_count,
            )

        run_batch_count += 1

        if max_batches is not None and run_batch_count == max_batches:
            print(
                f"Reached {run_batch_count} batches on max_batches of {max_batches}. Breaking from epoch."
            )
            # interrupted_batch = batch_idx
            done = True
            break

        if batch_idx % 251 == 0 and batch_idx != 0:
            print(
                f"Checkpointing on batch: {batch_idx}. Accuracy: {accuracy}. Loss per char: {loss_per_char}. Time: {time.time()}"
            )
            print(f"Last question is {batch_qs[-1]}")

            state = build_checkpoint(
                name=name,
                model=model,
                optimizer=optimizer,
                acc=accuracy,
                loss=loss_per_char,
                epoch=epoch,
                run_batches=run_batch_count,
                start_batch=batch_idx + 1,
                total_loss=total_loss,
                n_char_total=n_char_total,
                n_char_correct=n_char_correct,
                lr=warmup_lr,
            )

            save_checkpoint(state=state,
                            name=f"{name}_latest_checkpoint",
                            path="./checkpoints")

        # if utils.is_preempted():
        #     print(
        #         f"Preemption at end of Epoch batch: {batch_idx} and new Run batch: {run_batch_count}. Breaking from epoch."
        #     )
        #     interrupted_batch = batch_idx
        #     break

    if tb is not None and not utils.is_preempted():
        tb.add_scalars(
            {
                "loss_per_char": loss_per_char,
                "accuracy": accuracy
            },
            group="train",
            sub_group="epoch",
            global_step=epoch,
        )

    return loss_per_char, accuracy, run_batch_count, done
Example #9
try:
    # run vis on main thread always...yes, much better
    if not args.unity_env:  # no need for vis thread if using unity
        vis_network_thread(global_network)
except Exception as e:
    print("exception in vis thread, exiting")
    stop_requested = True
    print(e)

# signal.pause() # vis thread effectively pauses!

# if visualize_global:
#     vis_thread.join()
#
# else:

for t in train_threads:
    t.join()

print("*****************************************************************************************")
print('Now saving data. Please wait')

checkpoints.save_checkpoint(checkpoint_name)
Example #10
    def train(self, epoch, train_loader):
        '''
            Trains the model for a single epoch
        '''

        train_size = int(0.9 * len(train_loader.dataset.train_data) /
                         self.config.batch_size)
        top1_sum, loss_sum, last_loss = 0.0, 0.0, 0.0
        N = 0

        print('\33[1m==> Training epoch # {}\033[0m'.format(str(epoch)))

        self.model.train()

        start_time_step = time.time()
        for step, (data, targets) in enumerate(train_loader):
            data_timer = time.time()

            if self.use_cuda:
                data = data.cuda()
                targets = targets.cuda()

            data, targets_a, targets_b, lam = mixup_data(
                data, targets, self.config.alpha, self.use_cuda)

            data = Variable(data)
            targets_a = Variable(targets_a)
            targets_b = Variable(targets_b)

            batch_size = data.size(0)

            if epoch != 1 or step != 0:
                self.scheduler.step(epoch=self.scheduler.cumulative_time)
            else:
                self.scheduler.step()

            # used for SGDR with seconds as budget
            start_time_batch = time.time()

            self.optimizer.zero_grad()

            outputs = self.model(data)
            loss_func = mixup_criterion(targets_a, targets_b, lam)
            loss = loss_func(self.criterion, outputs)
            loss.backward()

            self.optimizer.step()

            # used for SGDR with seconds as budget
            delta_time = time.time() - start_time_batch
            self.scheduler.cumulative_time += delta_time
            self.cumulative_time += delta_time
            self.scheduler.last_step = self.scheduler.cumulative_time - delta_time - 1e-10

            # Each time before the learning rate restarts we save a checkpoint in order to create snapshot ensembles after the training finishes
            if (epoch != 1 or step != 0) and (
                    self.cumulative_time > self.config.T_e + delta_time +
                    5) and (self.scheduler.last_step < 0):
                save_checkpoint(int(round(self.cumulative_time)), self.model,
                                self.optimizer, self.scheduler, self.config)

            top1 = self.compute_score_train(outputs, targets_a, targets_b,
                                            batch_size, lam)
            top1_sum += top1 * batch_size
            last_loss = loss
            loss_sum += loss * batch_size
            N += batch_size

            #print(' | Epoch: [%d][%d/%d]   Time %.3f  Data %.3f  Err %1.3f  top1 %7.2f  lr %.4f'
            #        % (epoch, step + 1, train_size, self.cumulative_time, data_timer - start_time_step, loss.data, top1, self.scheduler.get_lr()[0]))

            start_time_step = time.time()

            if self.cumulative_time >= self.config.budget:
                print(
                    ' * Stopping at Epoch: [%d][%d/%d] for a budget of %.3f s'
                    % (epoch, step + 1, train_size, self.config.budget))
                return top1_sum / N, loss_sum / N, True

        return top1_sum / N, loss_sum / N, False
Example #11
def train(model, data, dtype, args):
    iter_count = 0
    disc_solver = model.disc_optimizer(args['learning_rate'], args['beta1'],
                                       args['beta2'], args['weight_decay'],
                                       args['optim'])
    gen_solver = model.gen_optimizer(args['learning_rate'], args['beta1'],
                                     args['beta2'], args['weight_decay'],
                                     args['optim'])

    starting_epoch = 0
    if args['resume']:
        checkpoints.load_checkpoint(model, CHECKPOINTS_DIR, args['resume'])
        starting_epoch = args['resume'] // len(data)
        iter_count = args['resume']
        print('Loading checkpoint of {name} at {chkpt_iter}'.format(
            name=model.name, chkpt_iter=args['resume']))

    for epoch in tqdm(range(starting_epoch, args['num_epochs']),
                      desc='epochs',
                      position=1):
        for batch_index, batch in tqdm(enumerate(data),
                                       desc='iterations',
                                       position=2,
                                       total=len(data)):
            if args['resume'] and batch_index < iter_count % len(data):
                continue
            try:
                x, _ = batch
            except ValueError:
                x = batch

            x = x.type(dtype)  # Batch data

            # Calculate number of discriminator iterations
            disc_iterations = args['disc_warmup_iterations'] if (
                batch_index < args['disc_warmup_length']
                or batch_index % args['disc_rapid_train_interval']
                == 0) else args['disc_iterations']

            # Train discriminator
            for _ in range(disc_iterations):
                disc_solver.zero_grad()
                real_data = model.preprocess_data(x).type(dtype)
                noise = model.sample_noise(args['batch_size']).type(dtype)
                disc_loss, fake_images = model.disc_loss(
                    real_data, noise, True)
                disc_loss_gp = disc_loss + model.gradient_penalty(
                    x, fake_images, lambda_val=args['lambda_val'])
                disc_loss_gp.backward()
                disc_solver.step()

            # Train generator
            gen_solver.zero_grad()
            noise = model.sample_noise(args['batch_size']).type(dtype)
            gen_loss = model.gen_loss(noise)
            gen_loss.backward()
            gen_solver.step()

            if iter_count % args['losses_every'] == 0:
                reporter.visualize_scalar(
                    -disc_loss.item(),
                    'Wasserstein distance between x and g',
                    iteration=iter_count,
                    env=model.name)
                reporter.visualize_scalar(gen_loss.item(),
                                          'Generator loss',
                                          iteration=iter_count,
                                          env=model.name)

            # send sample images to the visdom server.
            if iter_count % args['images_every'] == 0:
                reporter.visualize_images(
                    model.sample_images(args['sample_size']).data,
                    'generated samples {}'.format(iter_count //
                                                  args['images_every']),
                    env=model.name)

            if iter_count % args['checkpoint_every'] == 0 and iter_count != 0:
                checkpoints.save_checkpoint(model, CHECKPOINTS_DIR, iter_count,
                                            args)

            iter_count += 1
        args['resume'] = None
    samples = model.sample_images(args['sample_size']).data
    reporter.visualize_images(samples,
                              'final generated samples',
                              env=model.name)
    checkpoints.save_images(samples, args['tag'])
    checkpoints.save_checkpoint(model, FINALS_DIR, iter_count, args)
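
The checkpoints module in this example saves by iteration count. A minimal sketch of save_checkpoint consistent with the call sites (the file naming is an assumption; `args` is accepted to mirror the call but unused here):

import os
import torch

def save_checkpoint(model, directory, iter_count, args):
    # Hypothetical sketch: persist the model weights keyed by model name and
    # training iteration so a later load_checkpoint can locate them.
    os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(),
               os.path.join(directory, "{}_{}.pt".format(model.name, iter_count)))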
Example #12
signal.signal(signal.SIGINT, signal_handler)

print('Press Ctrl+C to stop')

try:
    # run vis on main thread always...yes, much better
    if not args.unity_env:  # no need for vis thread if using unity
        vis_network_thread(global_network)
except Exception as e:
    print("exception in vis thread, exiting")
    stop_requested = True
    print(e)

# signal.pause() # vis thread effectively pauses!

# if visualize_global:
#     vis_thread.join()
#
# else:

for t in train_threads:
    t.join()

print(
    "*****************************************************************************************"
)
print('Now saving data. Please wait')

checkpoints.save_checkpoint(checkpoint_name)