Example #1
import os
import random

import torch
# mp could equally be the stdlib multiprocessing; torch.multiprocessing is
# assumed here since tensors are exchanged with worker processes.
import torch.multiprocessing as mp
from torch.multiprocessing import Pipe, Process

# Project-local helpers (parse_args, print_args, log, data_loader, FakeModule,
# create_models, get_opt, worker, train) are assumed to be defined elsewhere
# in this repo.


def main():
    args = parse_args()

    mp.set_start_method('spawn')  # 'spawn' is required to use CUDA in subprocesses.
    _logger = log.get_logger(__name__, args)
    _logger.info(print_args(args))

    loaders = []
    file_list = os.listdir(args.train_file)
    random.shuffle(file_list)
    for i in range(args.worker):
        loader = data_loader.DataLoader(args.train_file,
                                        args.dict_file,
                                        separate_conj_stmt=args.direction,
                                        binary=args.binary,
                                        part_no=i,
                                        part_total=args.worker,
                                        file_list=file_list,
                                        norename=args.norename,
                                        filter_abelian=args.fabelian,
                                        compatible=args.compatible)
        loaders.append(loader)
        loader.start_reader()

    # Initialize the CUDA context in the parent process; guard the call so
    # CPU-only machines do not crash on .cuda().
    if torch.cuda.is_available():
        torch.randn(10).cuda()

    net, mid_net, loss_fn = create_models(args, loaders[0], allow_resume=True)
    # Use fake modules to replace the real ones
    net = FakeModule(net)
    if mid_net is not None:
        mid_net = FakeModule(mid_net)
    for i in range(len(loss_fn)):
        loss_fn[i] = FakeModule(loss_fn[i])
    opt = get_opt(net, mid_net, loss_fn, args)

    inqueues = []
    outqueues = []

    plist = []
    for i in range(args.worker):
        recv_p, send_p = Pipe(False)
        recv_p2, send_p2 = Pipe(False)
        inqueues.append(send_p)
        outqueues.append(recv_p2)
        plist.append(
            Process(target=worker,
                    args=(recv_p, send_p2, loaders[i], args, i)))
        plist[-1].start()

    _logger.warning('Training begins')
    train(inqueues, outqueues, net, mid_net, loss_fn, opt, loaders, args,
          _logger)
    for p in plist:
        p.terminate()
    for loader in loaders:
        loader.destruct()
    _logger.warning('Training ends')
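
The `worker` target passed to each `Process` is not defined in this snippet. Its request/reply protocol can be read off the `train` loop in Example #3: each step the trainer sends a dict of state dicts and control flags, and expects a dict of gradients and statistics back. The sketch below reconstructs that protocol only; the actual forward/backward pass over `loader` is elided and every numeric value is a placeholder.

def worker(recv_p, send_p, loader, args, worker_id):
    # Hedged reconstruction of the worker loop, inferred from Example #3.
    while True:
        msg = recv_p.recv()  # keys: fix_net, args, net?, mid_net?, loss_fn, test
        if msg['test']:
            # Validation request: reply with raw counts; train() sums them.
            send_p.send({'correct': 0, 'total': 1})
        else:
            # Training request: reply with per-parameter gradients plus stats.
            steps = msg['args'].loss_step
            send_p.send({
                'total': 1,                        # samples in this step
                'loss_total': 0.0,
                'net': {},                         # param name -> grad tensor
                'mid_net': {},
                'loss_fn': [{} for _ in range(steps)],
                'correct': [0] * steps,
                'buffer': {'loss_fn': [{} for _ in range(steps)]},
            })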
Example #2
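    # Fragment of a larger function. It assumes keras.models.model_from_json,
    # keras.callbacks.EarlyStopping, sklearn.metrics.average_precision_score,
    # the Janggu class, and a project-local get_opt helper are available in
    # the enclosing module.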
    jointmodel = Janggu(dnamodel.kerasmodel.inputs +
                        dnasemodel.kerasmodel.inputs,
                        output,
                        name='pretrained_dnase_dna_joint_model_{}_{}'.format(
                            dnasename, dnaname))

    # Reload the same model architecture; this randomly
    # reinitializes the weights.
    newjointmodel = model_from_json(jointmodel.kerasmodel.to_json())

    newjointmodel = Janggu(
        newjointmodel.inputs,
        newjointmodel.outputs,
        name='randominit_dnase_dna_joint_model_{}_{}'.format(
            dnasename, dnaname))
    newjointmodel.compile(optimizer=get_opt('amsgrad'),
                          loss='binary_crossentropy',
                          metrics=['acc'])

    hist = newjointmodel.fit(
        train_data[0],
        train_data[1],
        epochs=shared_space['epochs'],
        batch_size=64,
        validation_data=val_data,
        callbacks=[EarlyStopping(patience=5, restore_best_weights=True)])

    pred_test = newjointmodel.predict(test_data[0])
    pred_val = newjointmodel.predict(val_data[0])

    auprc_val = average_precision_score(val_data[1][:], pred_val)
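
The held-out test AUPRC can be computed the same way as the validation score; a one-line sketch, assuming `test_data` follows the same (inputs, labels) layout as `val_data`:

    auprc_test = average_precision_score(test_data[1][:], pred_test)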
Example #3
import time

from torch.autograd import Variable  # legacy wrapper; a no-op in PyTorch >= 0.4

# Project-local helpers (HistoryRecorder, save_model, get_opt) are assumed to
# be defined elsewhere in this repo.


def train(inqueues, outqueues, net, mid_net, loss_fn, opt, loaders, args,
          _logger):
    def _update_grad(net, mid_net, loss_fn, grad_list):
        if not args.fix_net:
            for name, param in net.named_parameters():
                grad_step = sum([
                    data['total'] * data['net'][name] for data in grad_list
                ]) / sum([data['total'] for data in grad_list])
                param.grad = Variable(grad_step)

            if mid_net is not None:
                for name, param in mid_net.named_parameters():
                    grad_step = sum([
                        data['total'] * data['mid_net'][name]
                        for data in grad_list
                    ]) / sum([data['total'] for data in grad_list])
                    param.grad = Variable(grad_step)

        for i in range(args.loss_step):
            for name, param in loss_fn[i].named_parameters():
                grad_step = sum([
                    data['total'] * data['loss_fn'][i][name]
                    for data in grad_list
                ]) / sum([data['total'] for data in grad_list])
                param.grad = Variable(grad_step)

            for name in loss_fn[i].buffers.keys():
                loss_fn[i].buffers[name].copy_(
                    grad_list[0]['buffer']['loss_fn'][i][name])

    def _update_correct(grad_list, corrects):
        for i in range(len(corrects)):
            corrects[i] += sum([data['correct'][i] for data in grad_list])

    def _step_stat(grad_list):
        step_sample_total = sum([data['total'] for data in grad_list])
        step_loss_total = sum([data['loss_total'] for data in grad_list])
        return step_sample_total, step_loss_total

    # Variables for training
    t = time.time()
    cur_epoch_training_total = 0
    training_total = 0
    valid_total = 0
    sample_total = 0
    loss_total = 0
    correct_total = [0] * args.loss_step

    recorder = None
    if args.record is not None:
        recorder = HistoryRecorder(args.record)

    epoch = 0
    epoch_end = False
    while epoch < args.epoch:
        data = {}

        # fix_net takes effect only after the first step, so workers always
        # receive the initial network weights at least once.
        data['fix_net'] = args.fix_net and training_total > 0

        data['args'] = args
        if not data['fix_net']:
            data['net'] = net.state_dict()
            if mid_net is not None:
                data['mid_net'] = mid_net.state_dict()

        data['loss_fn'] = []
        for loss in loss_fn:
            data['loss_fn'].append(loss.state_dict())
        data['test'] = False

        for i in range(args.worker):
            inqueues[i].send(data)

        grad_list = []
        for i in range(args.worker):
            data = outqueues[i].recv()
            grad_list.append(data)

        _update_grad(net, mid_net, loss_fn, grad_list)
        _update_correct(grad_list, correct_total)

        step_sample_total, step_loss_total = _step_stat(grad_list)

        cur_epoch_training_total += step_sample_total
        training_total += step_sample_total
        valid_total += step_sample_total
        sample_total += step_sample_total
        loss_total += step_loss_total
        opt.step()
        if (epoch + 1) * args.epoch_len <= training_total:
            _logger.info('Epoch END!!!')
            epoch_end = True

        if sample_total > args.observe or epoch_end:
            end_str = ' END!' if epoch_end else ''
            _logger.info('Epoch: %d%s Iteration: %d Loss: %.5f perTime: %.3f',
                         epoch, end_str, cur_epoch_training_total,
                         loss_total / sample_total,
                         (time.time() - t) / sample_total)
            accs = []
            for k in range(len(loss_fn)):
                accs.append('acc %d: %.5f' %
                            (k, correct_total[k] / sample_total))
            if recorder is not None:
                recorder.train_acc(training_total,
                                   correct_total[-1] / sample_total)
                recorder.save_record()
            _logger.info(' '.join(accs))
            sample_total = 0
            loss_total = 0
            correct_total = [0] * args.loss_step
            t = time.time()

        if valid_total > args.check_num or (epoch_end
                                            and epoch == args.epoch - 1):
            aux = {}
            aux['epoch'] = epoch
            aux['cur_iter'] = cur_epoch_training_total
            aux['total_iter'] = training_total
            save_model(
                aux, args, net, mid_net, loss_fn,
                args.output + '_%d_%d' % (epoch, cur_epoch_training_total))
            _logger.warning(
                'Model saved to %s',
                args.output + '_%d_%d' % (epoch, cur_epoch_training_total))
            _logger.warning('Start validation!')
            valid_start = time.time()
            data = {}
            data['fix_net'] = False
            data['args'] = args
            data['net'] = net.state_dict()
            if mid_net is not None:
                data['mid_net'] = mid_net.state_dict()
            data['loss_fn'] = []
            for loss in loss_fn:
                data['loss_fn'].append(loss.state_dict())
            data['test'] = True

            for i in range(args.worker):
                inqueues[i].send(data)

            result_correct = 0
            result_total = 0

            for i in range(args.worker):
                data = outqueues[i].recv()
                result_correct += data['correct']
                result_total += data['total']
            result_ = result_correct / result_total
            _logger.warning(
                'Validation complete! Time lapse: %.3f, Test acc: %.5f',
                time.time() - valid_start, result_)

            if recorder is not None:
                recorder.test_acc(training_total, result_)
                recorder.save_record()
            valid_total = 0
            if args.fix_net:
                _logger.warning('learning rate decreases from %.6f to %.6f',
                                args.learning_rate, args.learning_rate / 2)
                args.learning_rate /= 2
                opt = get_opt(net, mid_net, loss_fn, args)

        if args.unfix_net_after is not None and training_total > args.unfix_net_after:
            args.fix_net = False

        if epoch_end and args.learning_rate > args.min_lr:
            _logger.warning('learning rate decreases from %.6f to %.6f',
                            args.learning_rate, args.learning_rate / 3)
            args.learning_rate /= 3
            opt = get_opt(net, mid_net, loss_fn, args)

        if epoch_end:
            cur_epoch_training_total = 0
            epoch_end = False
            epoch += 1
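
The core of `_update_grad` above is a sample-count-weighted mean of the per-worker gradients. A minimal self-contained sketch of that reduction (names hypothetical):

import torch

def weighted_grad_mean(grads, counts):
    # Average one parameter's gradients across workers, weighting each
    # worker by the number of samples it processed.
    return sum(n * g for n, g in zip(counts, grads)) / sum(counts)

# Usage: two workers, the first processing twice as many samples.
g = weighted_grad_mean([torch.ones(3), torch.zeros(3)], [2, 1])
# g -> tensor([0.6667, 0.6667, 0.6667])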