Ejemplo n.º 1
0
def train():
    """Alternate between training the shared child model and the controller.

    Builds a child ``RNNModelSearch`` and a ``Controller`` (or resumes both
    from ``model.pt`` / ``optimizer.pt`` / ``misc.pt`` in their respective
    ``model_dir``), then repeats:

    1. train the child model for ``eval_every_epochs`` epochs on the current
       architecture pool (the pool is seeded with ``generate_arch`` when
       ``child_params['arch_pool']`` is ``None``);
    2. score the pool with ``model_search.evaluate``, sort it ascending by
       that score and write architectures and scores to the ``arch_pool*``
       files in the child ``model_dir``;
    3. train the controller on the (architecture, normalized score) pairs;
    4. ask the controller for new architectures and rebuild the pool from
       the head of the sorted old pool, the new architectures and 50 freshly
       generated ones.

    The loop stops once ``child_epoch`` reaches
    ``child_params['train_epochs']``.
    """
    child_params = get_child_model_params()
    controller_params = get_controller_params()
    corpus = data.Corpus(child_params['data_dir'])
    eval_batch_size = child_params['eval_batch_size']

    train_data = batchify(corpus.train, child_params['batch_size'],
                          child_params['cuda'])
    val_data = batchify(corpus.valid, eval_batch_size, child_params['cuda'])
    ntokens = len(corpus.dictionary)

    child_model_dir = child_params['model_dir']
    controller_model_dir = controller_params['model_dir']

    # --- Child model: resume from checkpoint if one exists -----------------
    # NOTE: torch.load unpickles the file; only point model_dir at trusted
    # checkpoints.
    child_model_path = os.path.join(child_model_dir, 'model.pt')
    continue_train_child = os.path.exists(child_model_path)
    if continue_train_child:
        print("Found model.pt in {}, automatically continue training.".format(
            child_model_dir))
        child_model = torch.load(child_model_path)
    else:
        child_model = model_search.RNNModelSearch(
            ntokens, child_params['emsize'], child_params['nhid'],
            child_params['nhidlast'], child_params['dropout'],
            child_params['dropouth'], child_params['dropoutx'],
            child_params['dropouti'], child_params['dropoute'],
            child_params['drop_path'])

    # --- Controller model: resume from checkpoint if one exists ------------
    controller_model_path = os.path.join(controller_model_dir, 'model.pt')
    continue_train_controller = os.path.exists(controller_model_path)
    if continue_train_controller:
        # Fixed: this message used to report the child model directory.
        print("Found model.pt in {}, automatically continue training.".format(
            controller_model_dir))
        controller_model = torch.load(controller_model_path)
    else:
        controller_model = controller.Controller(controller_params)

    child_size = sum(p.nelement() for p in child_model.parameters())
    logging.info('child model param size: {}'.format(child_size))
    controller_size = sum(p.nelement() for p in controller_model.parameters())
    logging.info('controller model param size: {}'.format(controller_size))

    # NOTE(review): device placement reads the global ``args`` while the data
    # above is batched according to child_params['cuda'] -- confirm the two
    # always agree.
    if args.cuda:
        if args.single_gpu:
            parallel_child_model = child_model.cuda()
            parallel_controller_model = controller_model.cuda()
        else:
            parallel_child_model = nn.DataParallel(child_model, dim=1).cuda()
            parallel_controller_model = nn.DataParallel(controller_model,
                                                        dim=1).cuda()
    else:
        parallel_child_model = child_model
        parallel_controller_model = controller_model

    # Log output kept as in the original (args is logged once per model).
    logging.info('Args: {}'.format(args))
    logging.info('Child Model total parameters: {}'.format(child_size))
    logging.info('Args: {}'.format(args))
    logging.info(
        'Controller Model total parameters: {}'.format(controller_size))

    # --- Child optimizer ----------------------------------------------------
    if continue_train_child:
        child_optimizer_state = torch.load(
            os.path.join(child_model_dir, 'optimizer.pt'))
        # A 't0' entry in the saved param group means the checkpoint was
        # written by an ASGD optimizer, so rebuild the same kind.
        if 't0' in child_optimizer_state['param_groups'][0]:
            child_optimizer = torch.optim.ASGD(
                child_model.parameters(),
                lr=child_params['lr'],
                t0=0,
                lambd=0.,
                weight_decay=child_params['wdecay'])
        else:
            child_optimizer = torch.optim.SGD(
                child_model.parameters(),
                lr=child_params['lr'],
                weight_decay=child_params['wdecay'])
        child_optimizer.load_state_dict(child_optimizer_state)
        # Resume from the stored epoch counter minus one.
        child_epoch = torch.load(
            os.path.join(child_model_dir, 'misc.pt'))['epoch'] - 1
    else:
        child_optimizer = torch.optim.SGD(child_model.parameters(),
                                          lr=child_params['lr'],
                                          weight_decay=child_params['wdecay'])
        child_epoch = 0

    # --- Controller optimizer -----------------------------------------------
    controller_optimizer = torch.optim.Adam(
        controller_model.parameters(),
        lr=controller_params['lr'],
        weight_decay=controller_params['weight_decay'])
    if continue_train_controller:
        controller_optimizer.load_state_dict(
            torch.load(os.path.join(controller_model_dir, 'optimizer.pt')))
        controller_epoch = torch.load(
            os.path.join(controller_model_dir, 'misc.pt'))['epoch'] - 1
    else:
        controller_epoch = 0

    num_top_archs = 100  # head of the sorted pool handed to controller.infer
    num_random_archs = 50  # freshly generated archs added to every new pool

    eval_every_epochs = child_params['eval_every_epochs']
    while True:
        # ---- Train child model --------------------------------------------
        if child_params['arch_pool'] is None:
            child_params['arch_pool'] = generate_arch(
                controller_params['num_seed_arch'])  # [[arch]]
        child_params['arch'] = None

        if isinstance(eval_every_epochs, int):
            child_params['eval_every_epochs'] = eval_every_epochs
        else:
            # A sequence of epoch thresholds: use the first one that the
            # current epoch has not reached yet.
            # NOTE(review): if child_epoch is not below any threshold, the
            # previous child_params['eval_every_epochs'] value is left in
            # place (on the first pass that is the raw configured sequence).
            eval_every_epochs = list(map(int, eval_every_epochs))
            for threshold in eval_every_epochs:
                if child_epoch < threshold:
                    child_params['eval_every_epochs'] = threshold
                    break

        for _ in range(child_params['eval_every_epochs']):
            child_epoch += 1
            model_search.train(train_data, child_model, parallel_child_model,
                               child_optimizer, child_params, child_epoch)
            if child_epoch % child_params['eval_every_epochs'] == 0:
                save_checkpoint(child_model, child_optimizer, child_epoch,
                                child_model_dir)
                logging.info('Saving Model!')
            if child_epoch >= child_params['train_epochs']:
                break

        # ---- Evaluate the current pool --------------------------------------
        valid_accuracy_list = model_search.evaluate(val_data, child_model,
                                                    parallel_child_model,
                                                    child_params,
                                                    eval_batch_size)

        # Sort architectures ascending by their evaluated score.
        sorted_indices = np.argsort(valid_accuracy_list)
        old_archs = np.array(child_params['arch_pool'])[sorted_indices].tolist()
        old_archs_perf = np.array(valid_accuracy_list)[sorted_indices].tolist()

        # Write both an epoch-stamped snapshot and the "latest" files.
        with open(os.path.join(child_model_dir,
                               'arch_pool.{}'.format(child_epoch)),
                  'w') as fa, \
                open(os.path.join(child_model_dir,
                                  'arch_pool.perf.{}'.format(child_epoch)),
                     'w') as fp, \
                open(os.path.join(child_model_dir, 'arch_pool'),
                     'w') as fa_latest, \
                open(os.path.join(child_model_dir, 'arch_pool.perf'),
                     'w') as fp_latest:
            for arch, perf in zip(old_archs, old_archs_perf):
                arch_line = '{}\n'.format(' '.join(map(str, arch)))
                perf_line = '{}\n'.format(perf)
                fa.write(arch_line)
                fa_latest.write(arch_line)
                fp.write(perf_line)
                fp_latest.write(perf_line)

        if child_epoch >= child_params['train_epochs']:
            logging.info('Training finished!')
            break

        # ---- Train Encoder-Predictor-Decoder --------------------------------
        encoder_input = [parse_arch_to_seq(arch) for arch in old_archs]
        encoder_target = normalize_target(old_archs_perf)
        decoder_target = copy.copy(encoder_input)
        controller_params['batches_per_epoch'] = math.ceil(
            len(encoder_input) / controller_params['batch_size'])
        controller_epoch = controller.train(encoder_input, encoder_target,
                                            decoder_target, controller_model,
                                            parallel_controller_model,
                                            controller_optimizer,
                                            controller_params,
                                            controller_epoch)

        # ---- Generate new archs ---------------------------------------------
        new_archs = []
        controller_params['predict_lambda'] = 0
        top_archs = [parse_arch_to_seq(arch)
                     for arch in old_archs[:num_top_archs]]
        max_step_size = controller_params['max_step_size']
        max_new_archs = controller_params['max_new_archs']
        while len(new_archs) < max_new_archs:
            # Each round calls infer with predict_lambda increased by one.
            controller_params['predict_lambda'] += 1
            candidates = controller.infer(top_archs, controller_model,
                                          parallel_controller_model,
                                          controller_params)
            for arch in candidates:
                # Keep only archs that are neither in the old pool nor
                # already collected in this round.
                if arch not in encoder_input and arch not in new_archs:
                    new_archs.append(arch)
                if len(new_archs) >= max_new_archs:
                    break
            logging.info('{} new archs generated now'.format(len(new_archs)))
            if controller_params['predict_lambda'] >= max_step_size:
                break

        new_archs = [parse_seq_to_arch(seq) for seq in new_archs]  # [[arch]]
        num_new_archs = len(new_archs)
        logging.info("Generate {} new archs".format(num_new_archs))
        random_new_archs = generate_arch(num_random_archs)

        # Keep the head of the sorted old pool so the pool size stays
        # constant. Clamp at zero: a negative bound would otherwise be read
        # as "all but the last k" and keep most of the old pool.
        num_kept = max(0, len(old_archs) - num_new_archs - num_random_archs)
        new_arch_pool = old_archs[:num_kept] + new_archs + random_new_archs
        logging.info("Totally {} archs now to train".format(
            len(new_arch_pool)))
        child_params['arch_pool'] = new_arch_pool
        with open(os.path.join(child_model_dir, 'arch_pool'), 'w') as f:
            for arch in new_arch_pool:
                f.write('{}\n'.format(' '.join(map(str, arch))))
Ejemplo n.º 2
0
def train(train_data, dev_data):
    """Run one pass over ``train_data``, pairing each batch with a dev batch.

    For every training batch a batch is also drawn from ``dev_data``; both
    are handed to ``architect.step`` (whose returned loss is tracked as the
    validation loss), after which the model weights are updated on the
    training batch using NLL loss plus the activation-regularization terms.
    Running averages are logged every ``args.log_interval`` batches.

    Args:
        train_data: batch source supporting ``len()`` and ``next_batch()``;
            each batch is indexed with ``'relation'`` (targets) and
            ``'tokens'``.
        dev_data: batch source supporting ``next_batch()`` with the same
            batch keys.
    """
    assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'

    total_loss = 0
    total_valid_loss = 0
    start_time = time.time()

    for batch in range(len(train_data)):
        train_batch = train_data.next_batch()
        dev_batch = dev_data.next_batch()

        # The sampled value is unused now that sequence-length sampling is
        # disabled; the draw is kept so the NumPy RNG stream is unchanged.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.

        # Turn on training mode which enables dropout.
        model.train()
        optimizer.zero_grad()

        cur_data = train_batch
        cur_targets = train_batch['relation']
        cur_data_valid = dev_batch
        cur_targets_valid = dev_batch['relation']

        # Fresh hidden states sized to the actual number of examples in each
        # batch.
        hidden = model.init_hidden(len(train_batch['relation']))[0]
        hidden_valid = model.init_hidden(len(dev_batch['relation']))[0]
        assert hidden.shape[1] == cur_data['tokens'].shape[
            0], 'Hidden shape: {} | tokens shape: {}'.format(
                hidden.shape, cur_data['tokens'].shape)
        assert hidden_valid.shape[1] == cur_data_valid['tokens'].shape[
            0], 'Hidden shape: {} | tokens shape: {}'.format(
                hidden_valid.shape, cur_data_valid['tokens'].shape)

        hidden_valid, valid_loss = architect.step(hidden, cur_data,
                                                  cur_targets, hidden_valid,
                                                  cur_data_valid,
                                                  cur_targets_valid, optimizer,
                                                  args.unrolled)
        total_valid_loss += valid_loss.data

        # Assuming small_batch_size == batch_size, so gradients are not
        # accumulated: clear whatever architect.step left behind.
        optimizer.zero_grad()

        # Detach the hidden state from any graph built during the architect
        # step.
        hidden = torch.autograd.Variable(hidden.data)
        # Hidden should be all zeros
        print('hidden all zeros?: (not {})'.format(torch.sum(hidden)))
        log_prob, hidden, rnn_hs, dropped_rnn_hs = parallel_model(
            cur_data, hidden, return_h=True)
        raw_loss = nn.functional.nll_loss(log_prob, cur_targets)

        loss = raw_loss
        # Activation Regularization
        if args.alpha > 0:
            loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                              for dropped_rnn_h in dropped_rnn_hs[-1:])
        # Temporal Activation Regularization (slowness)
        loss = loss + sum(args.beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                          for rnn_h in rnn_hs[-1:])
        total_loss += raw_loss.data
        loss.backward()

        gc.collect()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in
        # RNNs (the un-suffixed `clip_grad_norm` is deprecated).
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        if batch % args.log_interval == 0:
            logging.info(parallel_model.genotype())
            print(F.softmax(parallel_model.weights, dim=-1))
            # NOTE: the very first report (batch 0) divides a single batch's
            # loss by log_interval, so it under-reports.
            cur_loss = total_loss / args.log_interval
            cur_valid_loss = total_valid_loss / args.log_interval
            elapsed = time.time() - start_time
            logging.info(
                '| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f} | valid loss: {:5.2f} | valid ppl: {:5.2f}'
                .format(epoch, batch, len(train_data),
                        optimizer.param_groups[0]['lr'],
                        elapsed * 1000 / args.log_interval, cur_loss,
                        math.exp(cur_loss), cur_valid_loss,
                        math.exp(cur_valid_loss)))
            # Reset BOTH running sums; the validation sum used to keep
            # growing for the whole epoch, inflating every later report.
            total_loss = 0
            total_valid_loss = 0
            start_time = time.time()
    print('Reached end of epoch training!')
Ejemplo n.º 3
0
def train():
    """Make one pass over ``train_data``, interleaving architect and weight steps.

    Each outer iteration picks a sequence length, temporarily rescales the
    learning rate by ``seq_len / args.bptt``, fetches one search batch and
    one training batch, and then, once per micro-batch slot, calls
    ``architect.step`` followed by a forward/backward pass of the model on
    the training batch. Gradients are clipped and applied once per outer
    iteration, and the running loss is logged every ``args.log_interval``
    batches.
    """
    assert args.batch_size % args.small_batch_size == 0, \
        "batch_size must be divisible by small_batch_size"

    running_loss = 0
    tick = time.time()
    ntokens = len(corpus.dictionary)  # not used below; retained from the original

    slots = args.batch_size // args.small_batch_size
    # One hidden state per micro-batch slot, for the train and search streams.
    hidden = [model.init_hidden(args.small_batch_size) for _ in range(slots)]
    hidden_valid = [
        model.init_hidden(args.small_batch_size) for _ in range(slots)
    ]

    batch = 0
    i = 0
    while i < train_data.size(0) - 1 - 1:
        # Use the full bptt 95% of the time, half of it otherwise.
        if np.random.random() < 0.95:
            bptt = args.bptt
        else:
            bptt = args.bptt / 2.0
        seq_len = int(bptt)

        # Scale the learning rate with the sequence length for this step
        # only; the saved value is restored after optimizer.step().
        base_lr = optimizer.param_groups[0]["lr"]
        optimizer.param_groups[0]["lr"] = base_lr * seq_len / args.bptt

        # Turn on training mode which enables dropout.
        model.train()

        search_offset = i % (search_data.size(0) - 1)
        data_valid, targets_valid = get_batch(search_data, search_offset, args)
        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        optimizer.zero_grad()

        # The slot offsets are not used for slicing: every slot sees the
        # whole batch (the code assumes small_batch_size == batch_size).
        slot_offsets = range(0, args.batch_size, args.small_batch_size)
        for s_id, _offset in enumerate(slot_offsets):
            cur_data = data
            cur_targets = targets.contiguous()
            cur_data_valid = data_valid
            cur_targets_valid = targets_valid.contiguous()

            # Detach hidden states from the graph that produced them so
            # backpropagation stops at the start of this batch.
            hidden[s_id] = repackage_hidden(hidden[s_id])
            hidden_valid[s_id] = repackage_hidden(hidden_valid[s_id])

            hidden_valid[s_id], grad_norm = architect.step(
                hidden[s_id], cur_data, cur_targets,
                hidden_valid[s_id], cur_data_valid, cur_targets_valid,
                optimizer, args.unrolled)

            # assuming small_batch_size = batch_size so we don't accumulate gradients
            optimizer.zero_grad()
            hidden[s_id] = repackage_hidden(hidden[s_id])

            outputs = parallel_model(cur_data, hidden[s_id], return_h=True)
            log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs = outputs
            flat_log_prob = log_prob.view(-1, log_prob.size(2))
            raw_loss = nn.functional.nll_loss(flat_log_prob, cur_targets)

            loss = raw_loss
            # Activation Regularization
            if args.alpha > 0:
                ar_penalty = sum(
                    args.alpha * dropped.pow(2).mean()
                    for dropped in dropped_rnn_hs[-1:])
                loss = loss + ar_penalty
            # Temporal Activation Regularization (slowness)
            tar_penalty = sum(
                args.beta * (states[1:] - states[:-1]).pow(2).mean()
                for states in rnn_hs[-1:])
            loss = loss + tar_penalty
            loss *= args.small_batch_size / args.batch_size
            running_loss += raw_loss.data * args.small_batch_size / args.batch_size
            loss.backward()

            gc.collect()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        optimizer.param_groups[0]["lr"] = base_lr

        if batch > 0 and batch % args.log_interval == 0:
            logging.info(parallel_model.genotype())
            print(F.softmax(parallel_model.weights, dim=-1))
            cur_loss = running_loss.item() / args.log_interval
            elapsed = time.time() - tick
            message = (
                "| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | "
                "loss {:5.2f} | ppl {:8.2f}"
            ).format(
                epoch,
                batch,
                len(train_data) // args.bptt,
                optimizer.param_groups[0]["lr"],
                elapsed * 1000 / args.log_interval,
                cur_loss,
                math.exp(cur_loss),
            )
            logging.info(message)
            running_loss = 0
            tick = time.time()

        batch += 1
        i += seq_len
Ejemplo n.º 4
0
def train():
    """Make one pass over ``train_data``, sampling an architecture per micro-batch.

    Each outer iteration picks a sequence length, temporarily rescales the
    learning rate by ``seq_len / args.bptt``, and splits the batch into
    micro-batches of ``args.small_batch_size`` columns. For every
    micro-batch a new architecture is sampled via
    ``parallel_model.sample_new_architecture()`` and the scaled loss is
    back-propagated; the accumulated gradients are then clipped and applied
    once. The running loss is logged every ``args.log_interval`` batches.
    """
    assert args.batch_size % args.small_batch_size == 0, 'batch_size must be divisible by small_batch_size'

    total_loss = 0
    start_time = time.time()
    # One hidden state per micro-batch slot.
    hidden = [
        model.init_hidden(args.small_batch_size)
        for _ in range(args.batch_size // args.small_batch_size)
    ]
    batch, i = 0, 0
    # Turn on training mode which enables dropout.
    model.train()
    while i < train_data.size(0) - 1 - 1:
        # Use the full bptt 95% of the time, half of it otherwise.
        bptt = args.bptt if np.random.random() < 0.95 else args.bptt / 2.
        seq_len = int(bptt)

        # Scale the learning rate with the sequence length for this step
        # only; the saved value is restored after optimizer.step().
        lr2 = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = lr2 * seq_len / args.bptt

        data, targets = get_batch(train_data, i, args, seq_len=seq_len)

        # Clear gradients once per batch, BEFORE the micro-batch loop. This
        # call used to sit inside the loop, which threw away the gradients
        # of every micro-batch except the last even though each loss is
        # scaled by small_batch_size / batch_size for accumulation. With
        # small_batch_size == batch_size the behaviour is unchanged.
        optimizer.zero_grad()

        s_id = 0
        for start in range(0, args.batch_size, args.small_batch_size):
            end = start + args.small_batch_size
            cur_data = data[:, start:end]
            cur_targets = targets[:, start:end].contiguous().view(-1)

            # Detach the hidden state from the graph that produced it so
            # backpropagation stops at the start of this batch.
            hidden[s_id] = repackage_hidden(hidden[s_id])

            parallel_model.sample_new_architecture()
            log_prob, hidden[s_id], rnn_hs, dropped_rnn_hs = parallel_model(
                cur_data, hidden[s_id], return_h=True)
            raw_loss = nn.functional.nll_loss(
                log_prob.view(-1, log_prob.size(2)), cur_targets)

            loss = raw_loss
            # Activation Regularization
            if args.alpha > 0:
                loss = loss + sum(args.alpha * dropped_rnn_h.pow(2).mean()
                                  for dropped_rnn_h in dropped_rnn_hs[-1:])
            # Temporal Activation Regularization (slowness)
            loss = loss + sum(args.beta *
                              (rnn_h[1:] - rnn_h[:-1]).pow(2).mean()
                              for rnn_h in rnn_hs[-1:])
            # Weight each micro-batch by its share of the full batch.
            loss *= args.small_batch_size / args.batch_size
            total_loss += raw_loss.data * args.small_batch_size / args.batch_size
            loss.backward()

            s_id += 1

            gc.collect()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()

        optimizer.param_groups[0]['lr'] = lr2
        if batch % args.log_interval == 0 and batch > 0:
            cur_loss = total_loss.item() / args.log_interval
            elapsed = time.time() - start_time
            logging.info(
                '| dag_epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch,
                    len(train_data) // args.bptt,
                    optimizer.param_groups[0]['lr'],
                    elapsed * 1000 / args.log_interval, cur_loss,
                    math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
        batch += 1
        i += seq_len