Пример #1
0
# Build the sequence tagger: a recurrent model on top of pretrained
# embeddings, configured entirely from the `config` dict.
tagger = RecurrentTagger(
    embedding=embedding,
    bio_embed_size=config['bio_embed_size'],
    hidden_size=config['hidden_size'],
    num_layers=config['num_layers'],
    dropout=config['dropout'],
    bidirectional=config['bidirectional']
)

# NOTE(review): unconditional .cuda() assumes a GPU is available — confirm.
tagger = tagger.cuda()

# ignore_index=-1 skips padding positions when computing the loss.
criterion = nn.CrossEntropyLoss(ignore_index=-1)
# criterion = FocalLoss(gamma=1.5, ignore_index=-1)
# criterion = BinaryFocalLoss(alpha=1, gamma=2, ignore_index=-1)

# Adam with L2 regularization applied through weight_decay.
optimizer = optim.Adam(tagger.parameters(), lr=config['learning_rate'], weight_decay=config['l2_reg'])

visualizer = Visualizer(plot_path)

# Best validation F1 seen so far (used below for model selection).
max_val_f1_score = 0

for epoch in range(config['num_epoches']):
    preds_collection = []
    labels_collection = []
    total_loss = 0
    total_samples = 0
    for i, data in enumerate(train_loader):
        tagger.train()
        optimizer.zero_grad()
        sentences, targets, labels = data
        sentences, targets, labels = sentences.cuda(), targets.cuda(), labels.cuda()
Пример #2
0
                                                 [transforms.ToTensor()])),
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              **kwargs)

    # MNIST test split: fixed order (shuffle=False) for reproducible
    # evaluation; drop_last keeps every batch exactly test_batch_size.
    test_loader = DataLoader(datasets.MNIST('../data',
                                            train=False,
                                            transform=transforms.Compose(
                                                [transforms.ToTensor()])),
                             batch_size=test_batch_size,
                             drop_last=True,
                             shuffle=False)

    # Model with Bernoulli noise injection; `device` is chosen by the caller.
    net = MNIST(noise='bernoulli').to(device)
    optimizer = optim.Adam(net.parameters(), lr=0.0003, weight_decay=1e-6)
    loss = nn.CrossEntropyLoss()

    # Running train/test losses, and per-epoch accuracy / error histories.
    rloss, tloss = 0, 0
    ac_, er_ = [], []
    for e in range(epochs):
        rloss = train(net, device, train_loader, optimizer, loss, e)
        if ((e + 1) % 50) == 0:
            optimizer.param_groups[-1]['lr'] /= 2
        print("Epoch")
        if (e % 2) == 0:
            ac, er = sampling_test(net,
                                   device,
                                   loss,
                                   test_loader,
                                   batch_size=test_batch_size,
Пример #3
0
    def run(self, model, input, target, batch_idx=0, batch_size=None, input_token=None):
        """Run the attack over one batch, binary-searching the loss constant.

        For each candidate scale constant, the additive modifier is optimized
        for ``self.max_steps`` steps and the lowest-distortion adversarial
        example per item is tracked, both per search step and overall.

        Args:
            model: victim model (must expose ``set_temp`` when ``self.mask``
                is None).
            input: clean input tensor; assumed (length, batch_size, nhim) —
                batch size is read from ``input.size(1)``.
            target: per-example target labels.
            batch_idx: batch index, used only for logging.
            batch_size: overrides ``input.size(1)`` when given.
            input_token: optional token-level representation; when given it
                is used as the canvas for adversarial outputs instead of
                ``input``.

        Returns:
            numpy array holding the best adversarial examples found.
        """
        if batch_size is None:
            batch_size = input.size(1)  # ([length, batch_size, nhim])
        # Binary-search bounds and current scale constant, one per example.
        lower_bound = np.zeros(batch_size)
        scale_const = np.ones(batch_size) * self.initial_const
        upper_bound = np.ones(batch_size) * 1e10

        # Overall best l2 distance, predicted label and adversarial sample.
        o_best_l2 = [1e10] * batch_size
        o_best_score = [-1] * batch_size
        if input_token is None:
            best_attack = input.cpu().detach().numpy()
            o_best_attack = input.cpu().detach().numpy()
        else:
            best_attack = input_token.cpu().detach().numpy()
            o_best_attack = input_token.cpu().detach().numpy()
        self.o_best_sent = {}
        self.best_sent = {}

        # Detached copies of the input; .clone().detach() replaces the
        # deprecated torch.tensor(existing_tensor) construction.
        input_var = input.clone().detach()
        self.itereated_var = input_var.clone().detach()
        # One-hot encode the targets for the loss function.
        target_onehot = torch.zeros(target.size() + (self.num_classes,))
        if self.cuda:
            target_onehot = target_onehot.cuda()
        target_onehot.scatter_(1, target.unsqueeze(1), 1.)
        target_var = target_onehot.clone().detach()

        # The modifier is the variable being optimized.  Create it on the
        # CPU and move it to the GPU only when CUDA is enabled — the old
        # code called .cuda() unconditionally and crashed on CPU-only runs.
        modifier = torch.zeros(input_var.size()).float()
        if self.cuda:
            modifier = modifier.cuda()
        modifier_var = modifier.clone().detach().requires_grad_(True)

        optimizer = optim.Adam([modifier_var], lr=args.lr)

        for search_step in range(self.binary_search_steps):
            if args.debugging:
                print('Batch: {0:>3}, search step: {1}'.format(batch_idx, search_step))
            if self.debug:
                print('Const:')
                for i, x in enumerate(scale_const):
                    print(i, x)
            # Best results for this search step only.
            best_l2 = [1e10] * batch_size
            best_score = [-1] * batch_size
            # The last iteration (if we run many steps) repeats the search
            # once with the upper bound as the constant.
            if self.repeat and search_step == self.binary_search_steps - 1:
                scale_const = upper_bound

            scale_const_tensor = torch.from_numpy(scale_const).float()
            if self.cuda:
                scale_const_tensor = scale_const_tensor.cuda()
            scale_const_var = scale_const_tensor.clone().detach()

            for step in range(self.max_steps):
                # Anneal the temperature linearly when requested; otherwise
                # keep it fixed at args.temp.
                if self.mask is None:
                    if args.decreasing_temp:
                        cur_temp = args.temp - (args.temp - 0.1) / (self.max_steps - 1) * step
                        model.set_temp(cur_temp)
                        if args.debugging:
                            print("temp:", cur_temp)
                    else:
                        model.set_temp(args.temp)

                loss, dist, output, adv_img, adv_sents = self._optimize(
                    optimizer,
                    model,
                    input_var,
                    modifier_var,
                    target_var,
                    scale_const_var,
                    input_token)

                # Track the best adversarial example per item: per search
                # step on untargeted success, overall on targeted success.
                for i in range(batch_size):
                    target_label = target[i]
                    output_logits = output[i]
                    output_label = np.argmax(output_logits)
                    di = dist[i]
                    if self.debug:
                        if step % 100 == 0:
                            print('{0:>2} dist: {1:.5f}, output: {2:>3}, {3:5.3}, target {4:>3}'.format(
                                i, di, output_label, output_logits[output_label], target_label))
                    if di < best_l2[i] and self._compare_untargeted(output_logits, target_label):
                        if self.debug:
                            print('{0:>2} best step,  prev dist: {1:.5f}, new dist: {2:.5f}'.format(
                                i, best_l2[i], di))
                        best_l2[i] = di
                        best_score[i] = output_label
                        best_attack[:, i] = adv_img[:, i]
                        self.best_sent[i] = adv_sents[i]
                    if di < o_best_l2[i] and self._compare(output_logits, target_label):
                        if self.debug:
                            print('{0:>2} best total, prev dist: {1:.5f}, new dist: {2:.5f}'.format(
                                i, o_best_l2[i], di))
                        o_best_l2[i] = di
                        o_best_score[i] = output_label
                        o_best_attack[:, i] = adv_img[:, i]
                        self.o_best_sent[i] = adv_sents[i]
                sys.stdout.flush()
                # end inner step loop

            # Count successes; when an item had no targeted success, fall
            # back to its best untargeted result from this search step.
            batch_failure = 0
            batch_success = 0
            for i in range(batch_size):
                if self._compare(o_best_score[i], target[i]) and o_best_score[i] != -1:
                    batch_success += 1
                elif self._compare_untargeted(best_score[i], target[i]) and best_score[i] != -1:
                    o_best_l2[i] = best_l2[i]
                    o_best_score[i] = best_score[i]
                    o_best_attack[:, i] = best_attack[:, i]
                    self.o_best_sent[i] = self.best_sent[i]
                    batch_success += 1
                else:
                    batch_failure += 1

            logger.info('Num failures: {0:2d}, num successes: {1:2d}\n'.format(batch_failure, batch_success))
            sys.stdout.flush()
            # end outer search loop

        return o_best_attack
Пример #4
0
def run(args):
    """Train a discriminator probe on top of a (pretrained) VAE encoder.

    Pipeline: build the vocab and datasets, construct a GaussianLSTM VAE
    (optionally loading weights from ``args.load_path``), train a linear
    or MLP discriminator with gradient accumulation and loss-plateau LR
    decay, and return the test accuracy of the best checkpoint.

    NOTE(review): ``clip_grad``, ``decay_epoch``, ``lr_decay`` and
    ``max_decay`` are read as globals — confirm they are defined at
    module level.
    """
    global logging
    logging = create_exp_dir(args.exp_dir, scripts_to_save=[])

    if args.cuda:
        logging('using cuda')
    logging(str(args))

    # Seed all RNGs; cudnn.deterministic trades speed for repeatability.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = True

    # Bookkeeping for LR decay on validation-loss plateau.
    opt_dict = {"not_improved": 0, "lr": 1., "best_loss": 1e4}

    # One token per line in the vocab file; the line number becomes the id.
    vocab = {}
    with open(args.vocab_file) as fvocab:
        for i, line in enumerate(fvocab):
            vocab[line.strip()] = i

    vocab = VocabEntry(vocab)

    train_data = MonoTextData(args.train_data, label=args.label, vocab=vocab)

    vocab_size = len(vocab)

    val_data = MonoTextData(args.val_data, label=args.label, vocab=vocab)
    test_data = MonoTextData(args.test_data, label=args.label, vocab=vocab)

    logging('Train data: %d samples' % len(train_data))
    logging('finish reading datasets, vocab size is %d' % len(vocab))
    logging('dropped sentences: %d' % train_data.dropped)
    #sys.stdout.flush()

    # Log roughly 10 times per epoch (in units of optimizer updates).
    log_niter = max(1, (len(train_data)//(args.batch_size * args.update_every))//10)

    model_init = uniform_initializer(0.01)
    emb_init = uniform_initializer(0.1)

    #device = torch.device("cuda" if args.cuda else "cpu")
    device = "cuda" if args.cuda else "cpu"
    args.device = device

    if args.enc_type == 'lstm':
        encoder = GaussianLSTMEncoder(args, vocab_size, model_init, emb_init)
        args.enc_nh = args.dec_nh
    else:
        raise ValueError("the specified encoder type is not supported")

    decoder = LSTMDecoder(args, vocab, model_init, emb_init)
    vae = VAE(encoder, decoder, args).to(device)

    # Optionally restore pretrained VAE weights before probing.
    if args.load_path:
        loaded_state_dict = torch.load(args.load_path)
        #curr_state_dict = vae.state_dict()
        #curr_state_dict.update(loaded_state_dict)
        vae.load_state_dict(loaded_state_dict)
        logging("%s loaded" % args.load_path)

    # if args.eval:
    #     logging('begin evaluation')
    #     vae.load_state_dict(torch.load(args.load_path))
    #     vae.eval()
    #     with torch.no_grad():
    #         test_data_batch = test_data.create_data_batch(batch_size=args.batch_size,
    #                                                       device=device,
    #                                                       batch_first=True)

    #         test(vae, test_data_batch, test_labels_batch, "TEST", args)
    #         au, au_var = calc_au(vae, test_data_batch)
    #         logging("%d active units" % au)
    #         # print(au_var)

    #         test_data_batch = test_data.create_data_batch(batch_size=1,
    #                                                       device=device,
    #                                                       batch_first=True)
    #         calc_iwnll(vae, test_data_batch, args)

    #     return

    # The discriminator consumes the VAE encoder's representation.
    # NOTE(review): no else branch — an unrecognized discriminator name
    # leaves `discriminator` undefined and fails below with NameError.
    if args.discriminator == "linear":
        discriminator = LinearDiscriminator(args, vae.encoder).to(device)
    elif args.discriminator == "mlp":
        discriminator = MLPDiscriminator(args, vae.encoder).to(device)

    if args.opt == "sgd":
        optimizer = optim.SGD(discriminator.parameters(), lr=args.lr, momentum=args.momentum)
        opt_dict['lr'] = args.lr
    elif args.opt == "adam":
        optimizer = optim.Adam(discriminator.parameters(), lr=0.001)
        opt_dict['lr'] = 0.001
    else:
        raise ValueError("optimizer not supported")

    iter_ = decay_cnt = 0
    best_loss = 1e4
    # best_kl = best_nll = best_ppl = 0
    # pre_mi = 0
    discriminator.train()
    start = time.time()

    # kl_weight = args.kl_start
    # if args.warm_up > 0:
    #     anneal_rate = (1.0 - args.kl_start) / (args.warm_up * (len(train_data) / args.batch_size))
    # else:
    #     anneal_rate = 0

    # dim_target_kl = args.target_kl / float(args.nz)

    # Pre-batch all three splits once, up front.
    train_data_batch, train_labels_batch = train_data.create_data_batch_labels(batch_size=args.batch_size,
                                                    device=device,
                                                    batch_first=True)

    val_data_batch, val_labels_batch = val_data.create_data_batch_labels(batch_size=128,
                                                device=device,
                                                batch_first=True)

    test_data_batch, test_labels_batch = test_data.create_data_batch_labels(batch_size=128,
                                                  device=device,
                                                  batch_first=True)

    # Gradient accumulation: step the optimizer every args.update_every
    # batches on the loss averaged over the accumulated examples.
    acc_cnt = 1
    acc_loss = 0.
    for epoch in range(args.epochs):
        report_loss = 0
        report_correct = report_num_words = report_num_sents = 0
        acc_batch_size = 0
        optimizer.zero_grad()
        for i in np.random.permutation(len(train_data_batch)):

            batch_data = train_data_batch[i]
            # Skip degenerate batches with fewer than two examples.
            if batch_data.size(0) < 2:
                continue
            batch_labels = train_labels_batch[i]
            batch_labels = [int(x) for x in batch_labels]

            batch_labels = torch.tensor(batch_labels, dtype=torch.long, requires_grad=False, device=device)

            batch_size, sent_len = batch_data.size()

            # not predict start symbol
            report_num_words += (sent_len - 1) * batch_size
            report_num_sents += batch_size
            acc_batch_size += batch_size

            # (batch_size)
            loss, correct = discriminator.get_performance(batch_data, batch_labels)

            acc_loss = acc_loss + loss.sum()

            if acc_cnt % args.update_every == 0:
                acc_loss = acc_loss / acc_batch_size
                acc_loss.backward()

                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), clip_grad)

                optimizer.step()
                optimizer.zero_grad()

                acc_cnt = 0
                acc_loss = 0
                acc_batch_size = 0

            acc_cnt += 1
            report_loss += loss.sum().item()
            report_correct += correct

            if iter_ % log_niter == 0:
                #train_loss = (report_rec_loss  + report_kl_loss) / report_num_sents
                train_loss = report_loss / report_num_sents


                logging('epoch: %d, iter: %d, avg_loss: %.4f, acc %.4f,' \
                       'time %.2fs' %
                       (epoch, iter_, train_loss, report_correct / report_num_sents,
                        time.time() - start))

                #sys.stdout.flush()

            iter_ += 1

        logging('lr {}'.format(opt_dict["lr"]))

        discriminator.eval()

        # Validate; save the checkpoint whenever validation loss improves.
        with torch.no_grad():
            loss, acc = test(discriminator, val_data_batch, val_labels_batch, "VAL", args)
            # print(au_var)

        if loss < best_loss:
            logging('update best loss')
            best_loss = loss
            best_acc = acc
            print(args.save_path)
            torch.save(discriminator.state_dict(), args.save_path)

        # Plateau handling: after `decay_epoch` non-improving epochs, decay
        # the LR, reload the best checkpoint, and rebuild the optimizer.
        if loss > opt_dict["best_loss"]:
            opt_dict["not_improved"] += 1
            if opt_dict["not_improved"] >= decay_epoch and epoch >= args.load_best_epoch:
                opt_dict["best_loss"] = loss
                opt_dict["not_improved"] = 0
                opt_dict["lr"] = opt_dict["lr"] * lr_decay
                discriminator.load_state_dict(torch.load(args.save_path))
                logging('new lr: %f' % opt_dict["lr"])
                decay_cnt += 1
                if args.opt == "sgd":
                    optimizer = optim.SGD(discriminator.parameters(), lr=opt_dict["lr"], momentum=args.momentum)
                    opt_dict['lr'] = opt_dict["lr"]
                elif args.opt == "adam":
                    optimizer = optim.Adam(discriminator.parameters(), lr=opt_dict["lr"])
                    opt_dict['lr'] = opt_dict["lr"]
                else:
                    raise ValueError("optimizer not supported")

        else:
            opt_dict["not_improved"] = 0
            opt_dict["best_loss"] = loss

        # Stop after max_decay LR reductions.
        if decay_cnt == max_decay:
            break

        if epoch % args.test_nepoch == 0:
            with torch.no_grad():
                loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)

        discriminator.train()


    # Final evaluation of the best checkpoint on the test split.
    discriminator.load_state_dict(torch.load(args.save_path))
    discriminator.eval()

    with torch.no_grad():
        loss, acc = test(discriminator, test_data_batch, test_labels_batch, "TEST", args)
        # print(au_var)

    return acc
Пример #5
0
        x = self.fc3(x)
        return x



net = Net()

########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# NOTE(review): the tutorial text above says "SGD with momentum", but the
# code actually uses Adam — the comment and code disagree.
optimizer = optim.Adam(net.parameters(), lr=0.0001)

########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize

for epoch in range(30):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
Пример #6
0
def main():
    """Train SSSRNet3 from scratch on VOC 160x160 patches.

    Parses CLI options, seeds the RNGs, optionally builds a VGG-19 content
    model for the perceptual loss, loads a pretrained DIV2K model for
    reference, and runs the epoch loop over shuffled HDF5 training shards,
    checkpointing every second epoch.
    """
    print("SSSRNet3 training from scratch on VOC 160*160 patches.")
    global opt, model, netContent
    opt = parser.parse_args()
    print(opt)
    gpuid = 0
    cuda = True
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    # Seed both CPU and GPU RNGs so a given run is reproducible.
    opt.seed = random.randint(1, 10000)
    print("Random Seed: ", opt.seed)
    torch.manual_seed(opt.seed)
    if cuda:
        torch.cuda.manual_seed(opt.seed)

    cudnn.benchmark = True

    if opt.vgg_loss:
        # Perceptual (content) loss: features from a pretrained VGG-19.
        print('===> Loading VGG model')
        netVGG = models.vgg19()
        netVGG.load_state_dict(
            model_zoo.load_url(
                'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'))

        class _content_model(nn.Module):
            """Truncated VGG-19 used as a fixed feature extractor."""

            def __init__(self):
                super(_content_model, self).__init__()
                self.feature = nn.Sequential(
                    *list(netVGG.features.children())[:-1])

            def forward(self, x):
                out = self.feature(x)
                return out

        netContent = _content_model()

    print("===> Building model")
    model = Net()
    print('Parameters: {}'.format(get_n_params(model)))
    model_pretrained = torch.load(
        'model/model_DIV2K_noBN_96_epoch_36.pth',
        map_location=lambda storage, loc: storage)["model"]
    finetune = False

    # Truthiness test instead of the non-idiomatic `== True` comparison.
    if finetune:
        # Copy pretrained weights into all but the first two parameters.
        index = 0
        for (src, dst) in zip(model_pretrained.parameters(),
                              model.parameters()):
            if index > 1:
                list(model.parameters())[index].data = src.data
            index = index + 1

    # `size_average=False` is deprecated; reduction='sum' is equivalent.
    criterion = nn.MSELoss(reduction='sum')

    print("===> Setting GPU")
    if cuda:
        model = model.cuda(gpuid)
        model_pretrained = model_pretrained.cuda(gpuid)
        criterion = criterion.cuda(gpuid)
        if opt.vgg_loss:
            netContent = netContent.cuda(gpuid)

    # Optionally resume training from a checkpoint.
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint["epoch"] + 1
            model.load_state_dict(checkpoint["model"].state_dict())
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Optionally copy weights from a checkpoint (weights only, no epoch).
    if opt.pretrained:
        if os.path.isfile(opt.pretrained):
            print("=> loading model '{}'".format(opt.pretrained))
            weights = torch.load(opt.pretrained)
            model.load_state_dict(weights['model'].state_dict())
        else:
            print("=> no model found at '{}'".format(opt.pretrained))

    print("===> Setting Optimizer")
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    print("===> Training1")
    #root_dir = '/tmp4/hang_data/DIV2K/DIV2K_train_320_HDF5'
    root_dir = '/tmp4/hang_data/VOCdevkit/VOC2012/VOC_train_label160_HDF5'
    files_num = len(os.listdir(root_dir))
    for epoch in range(opt.start_epoch, opt.nEpochs + 1):
        #save_checkpoint(model, epoch)
        print("===> Loading datasets")
        # Visit the HDF5 shards in a fresh random order each epoch.
        x = random.sample(os.listdir(root_dir), files_num)
        for index in range(0, files_num):
            train_path = os.path.join(root_dir, x[index])
            print("===> Training datasets: '{}'".format(train_path))
            train_set = DatasetFromHdf5(train_path)
            training_data_loader = DataLoader(dataset=train_set,
                                              num_workers=opt.threads,
                                              batch_size=opt.batchSize,
                                              shuffle=True)
            avgloss = train(training_data_loader, optimizer, model,
                            model_pretrained, criterion, epoch, gpuid)
        if epoch % 2 == 0:
            save_checkpoint(model, epoch)
Пример #7
0
            input_nc=3, n_layers=6,
            norm_layer=torch.nn.InstanceNorm2d).to(device)
    else:
        D = MultiscaleDiscriminator(
            input_nc=3, n_layers=7,
            norm_layer=torch.nn.InstanceNorm2d).to(device)
    G.train()
    D.train()

    # Identity encoder (ArcFace); kept frozen in eval mode and loaded with
    # strict=False so non-matching keys in the checkpoint are tolerated.
    arcface = Backbone(50, 0.6, 'ir_se').to(device)
    arcface.eval()
    arcface.load_state_dict(torch.load('./id_model/model_ir_se50.pth',
                                       map_location=device),
                            strict=False)

    # Separate optimizers for generator and discriminator; betas=(0, 0.999)
    # is a common choice for adversarial training.
    opt_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0, 0.999))
    opt_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0, 0.999))

    # Wrap models/optimizers for mixed-precision training (NVIDIA apex).
    G, opt_G = amp.initialize(G, opt_G, opt_level=optim_level)
    D, opt_D = amp.initialize(D, opt_D, opt_level=optim_level)

    try:
        if Flag_256:
            Finetune_G_Model = './saved_models/G_latest.pth'
            Finetune_D_Model = './saved_models/D_latest.pth'
        else:
            Finetune_G_Model = './saved_models/G_latest_512.pth'
            Finetune_D_Model = './saved_models/D_latest_512.pth'

        G.load_state_dict(torch.load(Finetune_G_Model,
                                     map_location=torch.device('cpu')),
  def __init__(self, config):
    """Build the trainer from `config`.

    Reads hyperparameters, selects the acoustic front-end and the audio
    network architecture, constructs the phone network, the optimizer and
    LR scheduler, and optionally restores a checkpoint.

    Args:
      config: configuration object with attribute access plus a dict-like
        `.get(key, default)` method.

    Raises:
      ValueError: if `config.audio_feature` or `config.model_type` is not
        one of the supported options.
    """
    self.config = config

    self.cuda = torch.cuda.is_available()
    self.epoch = config.epoch
    self.batch_size = config.batch_size
    self.beta = config.beta
    self.lr = config.lr
    self.n_layers = config.get('num_layers', 1)
    self.weight_phone_loss = config.get('weight_phone_loss', 1.)
    self.weight_word_loss = config.get('weight_word_loss', 1.)
    self.anneal_rate = config.get('anneal_rate', 3e-6)
    self.num_sample = config.get('num_sample', 1)
    self.eps = 1e-9  # numerical-stability constant
    self.max_grad_norm = config.get('max_grad_norm', None)
    # Acoustic front-end: raw MFCC features, or a frozen wav2vec2 encoder.
    if config.audio_feature == 'mfcc':
      self.audio_feature_net = None
      self.input_size = 80
      self.hop_len_ms = 10

    elif config.audio_feature == 'wav2vec2':
      self.audio_feature_net = cuda(fairseq.checkpoint_utils.load_model_ensemble_and_task([config.wav2vec_path])[0][0],
                                    self.cuda)
      # wav2vec2 is used as a fixed feature extractor — freeze it.
      for p in self.audio_feature_net.parameters():
        p.requires_grad = False
      self.input_size = 512
      self.hop_len_ms = 20
    else:
      raise ValueError(f"Feature type {config.audio_feature} not supported")

    self.K = config.K
    self.global_iter = 0
    self.global_epoch = 0
    self.audio_feature = config.audio_feature
    self.image_feature = config.image_feature
    self.debug = config.debug
    self.dataset = config.dataset

    # Dataset
    self.data_loader = return_data(config)
    self.n_visual_class = self.data_loader['train']\
                          .dataset.preprocessor.num_visual_words
    self.n_phone_class = self.data_loader['train']\
                         .dataset.preprocessor.num_tokens
    self.visual_words = self.data_loader['train']\
                        .dataset.preprocessor.visual_words
    print(f'Number of visual label classes = {self.n_visual_class}')
    print(f'Number of phone classes = {self.n_phone_class}')

    # Audio network, selected by model type.
    self.model_type = config.model_type
    if config.model_type == 'blstm':
      self.audio_net = cuda(GumbelBLSTM(
                              self.K,
                              input_size=self.input_size,
                              n_layers=self.n_layers,
                              n_class=self.n_visual_class,
                              n_gumbel_units=self.n_phone_class,
                              ds_ratio=1,
                              bidirectional=True), self.cuda)
      # Bidirectional output doubles the feature dimension.
      self.K = 2 * self.K
    elif config.model_type == 'mlp':
      self.audio_net = cuda(GumbelMLP(
                                self.K,
                                input_size=self.input_size,
                                n_class=self.n_visual_class,
                                n_gumbel_units=self.n_phone_class,
                            ), self.cuda)
    elif config.model_type == 'tds':
      self.audio_net = cuda(GumbelTDS(
                              input_size=self.input_size,
                              n_class=self.n_visual_class,
                              n_gumbel_units=self.n_phone_class,
                            ), self.cuda)
    elif config.model_type == 'vq-mlp':
      self.audio_net = cuda(VQMLP(
                              self.K,
                              input_size=self.input_size,
                              n_class=self.n_visual_class,
                              n_embeddings=self.n_phone_class
                            ), self.cuda)
    else:
      # Fail fast instead of leaving self.audio_net undefined (mirrors the
      # audio_feature validation above).
      raise ValueError(f"Model type {config.model_type} not supported")

    self.phone_net = cuda(GumbelBLSTM(
                            self.K,
                            input_size=self.n_phone_class,
                            n_layers=self.n_layers,
                            n_class=self.n_visual_class,
                            n_gumbel_units=self.n_phone_class,
                            ds_ratio=1,
                            bidirectional=True), self.cuda)

    # Only the audio network is optimized here.
    trainables = list(self.audio_net.parameters())
    optim_type = config.get('optim', 'adam')
    if optim_type == 'sgd':
      self.optim = optim.SGD(trainables, lr=self.lr)
    else:
      self.optim = optim.Adam(trainables,
                              lr=self.lr, betas=(0.5, 0.999))
    self.scheduler = lr_scheduler.ExponentialLR(self.optim, gamma=0.97)
    self.ckpt_dir = Path(config.ckpt_dir)
    if not self.ckpt_dir.exists():
      self.ckpt_dir.mkdir(parents=True, exist_ok=True)
    self.load_ckpt = config.load_ckpt
    if self.load_ckpt or config.mode == 'test':
      self.load_checkpoint()

    # History of the best metrics seen so far.
    self.history = dict()
    self.history['acc'] = 0.
    self.history['token_f1'] = 0.
    self.history['loss'] = 0.
    self.history['epoch'] = 0
    self.history['iter'] = 0
Пример #9
0
    def forward(self, x):
        """Shared-trunk forward pass.

        Runs the input through the common affine layer, then produces one
        action-probability distribution per environment head plus a
        state-value estimate.

        Returns:
            (probs_env1, probs_env2, state_values) where each probs tensor
            is a softmax over that environment's action scores.
        """
        hidden = F.relu(self.affine1(x))
        # Softmax each environment head's scores into probabilities.
        probs = tuple(
            F.softmax(head(hidden), dim=-1)
            for head in (self.action1_head, self.action2_head)
        )
        return probs[0], probs[1], self.value_head(hidden)

# Instantiate a single shared policy network used for both environments.
model = Policy()
# Adam optimizer over the shared policy; lr is the learning rate.
optimizer = optim.Adam(model.parameters(), lr=3e-2)
# Smallest representable float32 increment — presumably used later to avoid
# division by zero when normalizing returns (usage not visible here).
eps = np.finfo(np.float32).eps.item()

def select_action(state, env):
    '''given a state, this function chooses the action to take
    arguments: state - observation matrix specifying the current model state
               env - integer specifying which environment to sample action
                         for
    return - action to take'''
    state = torch.from_numpy(state).float()
    # retrain the model
    probs1, probs2, state_value = model(state)
    #select which probability to use based on passed argument
    if env == 1:
        probs = probs1
    if env == 2:
Пример #10
0
# Network sizes; N_A is the action-space dimension (output width).
netParameter['n_hidden'] = 128
netParameter['n_output'] = N_A

actorNet = ActorConvNet(netParameter['n_feature'],
                                    netParameter['n_hidden'],
                                    netParameter['n_output'])

# Target networks start as exact copies of the online networks.
actorTargetNet = deepcopy(actorNet)

criticNet = CriticConvNet(netParameter['n_feature'] ,
                            netParameter['n_hidden'],
                        netParameter['n_output'])

criticTargetNet = deepcopy(criticNet)

# Separate optimizers for actor and critic with their own learning rates.
actorOptimizer = optim.Adam(actorNet.parameters(), lr=config['actorLearningRate'])
criticOptimizer = optim.Adam(criticNet.parameters(), lr=config['criticLearningRate'])

actorNets = {'actor': actorNet, 'target': actorTargetNet}
criticNets = {'critic': criticNet, 'target': criticTargetNet}
optimizers = {'actor': actorOptimizer, 'critic':criticOptimizer}
agent = DDPGAgent(config, actorNets, criticNets, env, optimizers, torch.nn.MSELoss(reduction='mean'), N_A, stateProcessor=stateProcessor, experienceProcessor=experienceProcessor)


# When enabled, sample an N x N grid of states for policy visualization.
plotPolicyFlag = False
N = 100
if plotPolicyFlag:
    phi = 0.0

    xSet = np.linspace(-10,10,N)
    ySet = np.linspace(-10,10,N)
# NOTE(review): pickle.load(open(...)) leaves the file handles unclosed;
# also pickle is unsafe on untrusted files — confirm these are trusted.
val_data = pickle.load(open("validation.p", "rb"))
test_data = pickle.load(open("test.p", "rb"))

# Separate loaders for the unlabeled and labeled training pools.
train_loader_unlabeled = torch.utils.data.DataLoader(train_unlabeled, batch_size=64, shuffle=True)
train_loader_labeled = torch.utils.data.DataLoader(train_labeled, batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)

# Optional CLI override: the first argument sets both cost weights.
if len(sys.argv) > 1:
    unsup_cost = float(sys.argv[1])
    m_cost = float(sys.argv[1])

# model = DCVAE2_Pool_Deeper(cost_rec = unsup_cost)
model = DCVAE2_Pool_Deeper_Ladder(1, .1)

opt = optim.Adam(model.parameters(), lr=0.001)

nll = torch.nn.NLLLoss()
mse = torch.nn.MSELoss()

C = 1

def train_unsup():
    avg_loss = 0
    count = 0

    model.train()
    for batch_idx, (data, target) in enumerate(train_loader_unlabeled):
        data, target = Variable(data), Variable(target)
Пример #12
0
    model.load_tokendata(args.input_json)
    model.build_model(layertype=args.layer_type,
                      dropout=args.dropout,
                      num_layers=args.num_layers,
                      D=args.embedding_dim,
                      H=args.hidden_dim,
                      zoneout=args.zoneout)
else:
    model.load_json(args.load_model, clone_tensors=True)
    model.replace_tokendata(args.input_json)
print(model.layers)

# Log whether the model was freshly built or restored, plus its size.
logger.info('%s model with %d parameters' %
            ('Created' if args.load_model is None else 'Loaded',
             sum((p.numel() for p in model.parameters()))))
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
# Step-wise LR decay: multiply by lrdecay_factor every lrdecay_every steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, args.lrdecay_every,
                                      args.lrdecay_factor)
# Per-token losses are kept unreduced when masks will weight them later.
crit = nn.CrossEntropyLoss(reduction='none' if args.use_masks else 'mean')

logger.info('Loading data')

# Sequence batcher over the preprocessed HDF5 corpus.
loader = DataLoader(filename=args.input_h5,
                    batch_size=args.batch_size,
                    seq_length=args.seq_length)

device = torch.device(args.device)
if args.layerdevices:
    for ld in args.layerdevices:
        start, end, device = ld.split(',')
        for layerid in range(int(start), int(end) + 1):
def train_model(config):
	"""Train a convolutional VAE on single-channel MURA radiographs.

	config must provide 'z', the latent dimensionality of the nested
	VAE_Conv model.  Trains for a fixed 10 epochs on the GPU and reports
	the mean validation ELBO to ray.tune after every epoch.
	"""
	num_epochs = 10

	class Flatten(nn.Module):
		# Collapse (B, C, H, W) feature maps to (B, C*H*W) for the dense heads.
		def forward(self, input):
			return input.view(input.size(0), -1)

	class UnFlatten(nn.Module):
		# Inverse of Flatten: reshape (B, n_channels*S*S) back to (B, n_channels, S, S).
		def __init__(self, n_channels):
			super(UnFlatten, self).__init__()
			self.n_channels = n_channels
		def forward(self, input):
			# Assumes the flattened vector describes a square spatial map.
			size = int((input.size(1) // self.n_channels)**0.5)
			return input.view(input.size(0), self.n_channels, size, size)


	class VAE_Conv(nn.Module):
		"""
		Convolutional VAE (encoder -> mu/log-var heads -> decoder).
		Layer arithmetic follows
		https://github.com/vdumoulin/conv_arithmetic
		"""
		def __init__(self, z_dim= config['z'], img_channels=1, img_size=224):
			super(VAE_Conv, self).__init__()

			## encoder
			self.encoder = nn.Sequential(
				nn.Conv2d(img_channels, 32, (3,3), stride=2, padding=1),
				nn.ReLU(),
				nn.Conv2d(32, 64, (3,3), stride=2, padding=1),
				nn.ReLU(),
				nn.Conv2d(64, 128, (3,3), stride=2, padding=2),
				nn.ReLU(),
				nn.Conv2d(128, 64, (3,3), stride = 2, padding = 1),
				nn.ReLU(),
				nn.Conv2d(64, 32, (3,3), stride = 2, padding = 1),
				nn.ReLU(),
				Flatten()
			)

			## output size depends on input image size
			# A dummy forward pass infers the flattened feature size h_dim
			# instead of computing the conv arithmetic by hand.
			demo_input = torch.ones([1,img_channels,img_size,img_size])
			h_dim = self.encoder(demo_input).shape[1]
			print('h_dim', h_dim)
			## map to latent z
			# fc11 -> mean head, fc12 -> log-variance head (see encode()).
			# h_dim = convnet_to_dense_size(img_size, encoder_params)
			self.fc11 = nn.Linear(h_dim, z_dim)
			self.fc12 = nn.Linear(h_dim, z_dim)

			## decoder
			self.fc2 = nn.Linear(z_dim, h_dim)
			self.decoder = nn.Sequential(
				UnFlatten(32),
				nn.ConvTranspose2d(32, 64, (4,4), stride=2, padding=2),
				nn.ReLU(),
				nn.ConvTranspose2d(64,32, (5,5),stride = 2, padding = 1),
				nn.ReLU(),
				nn.ConvTranspose2d(32, 16, (5,5), stride=2, padding=2),
				nn.ReLU(),
				nn.ConvTranspose2d(16,8, (5,5), stride = 2, padding = 2),
				nn.ReLU(),
				nn.ConvTranspose2d(8,8, (5,5), stride = 2, padding = 2),
				nn.ReLU(),
				nn.ConvTranspose2d(8, img_channels, (4,4), stride=1, padding=2),
				nn.Sigmoid()
			)



		def encode(self, x):
			# Returns (mu, log_var) of the approximate posterior q(z|x).
			h = self.encoder(x)
			return self.fc11(h), self.fc12(h)

		def reparameterize(self, mu, logvar):
			# Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
			std = torch.exp(0.5*logvar)
			eps = torch.randn_like(std)
			return mu + eps*std

		def decode(self, z):
			img = self.decoder(self.fc2(z))
			return img

		def forward(self, x):
			outputs = {}
			mu, logvar = self.encode(x)
			z = self.reparameterize(mu, logvar)
			outputs['x_hat'] = self.decode(z)
			outputs['mu'] = mu
			outputs['z'] = z
			outputs['log_var'] = logvar
			return outputs

	model = VAE_Conv().cuda()

	# List every trainable parameter for inspection.
	for name, param in model.named_parameters():
		if param.requires_grad == True:
				print("\t",name)


	#----------------------------------------------------------------------------------
	def ELBO_loss(y, t, mu, log_var):
		# Returns (-ELBO, summed KL) so the first value can be minimised
		# directly and the second logged for monitoring.
		# Reconstruction error, log[p(x|z)]
		# Sum over features
		likelihood = -binary_cross_entropy(y, t, reduction="none")
		likelihood = likelihood.view(likelihood.size(0), -1).sum(1)

		# Regularization error: 
		# Kulback-Leibler divergence between approximate posterior, q(z|x)
		# and prior p(z) = N(z | mu, sigma*I).
		
		# In the case of the KL-divergence between diagonal covariance Gaussian and 
		# a standard Gaussian, an analytic solution exists. Using this excerts a lower
		# variance estimator of KL(q||p)
		kl = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var), dim=1)

		# Combining the two terms in the evidence lower bound objective (ELBO) 
		# mean over batch
		ELBO = torch.mean(likelihood) - torch.mean(kl)
		
		# notice minus sign as we want to maximise ELBO
		return -ELBO, kl.sum()

	optimizer = optim.Adam(model.parameters(), lr=0.001)
	loss_function = ELBO_loss

	#-----------------------------------------------------------------------------
	#Setup working directoy
	os.chdir('/zhome/b8/a/122402/Bachelor/Data')


	train_path='train/'
	valid_path='valid/'

	#Load data: 
	train_img_path=pd.read_csv('MURA-v1.1/train_image_paths.csv')
	valid_img_path=pd.read_csv('MURA-v1.1/valid_image_paths.csv')
	train_labels=pd.read_csv('MURA-v1.1/train_labeled_studies.csv')
	valid_labels=pd.read_csv('MURA-v1.1/valid_labeled_studies.csv')

	#Label anomaly
	# 1 when the study path contains 'positive' (abnormal), else 0.
	train_img_path['Label']=train_img_path.Img_Path.apply(lambda x:1 if 'positive' in x else 0)
	valid_img_path['Label']=valid_img_path.Img_Path.apply(lambda x:1 if 'positive' in x else 0)

	#Label one-hot-encode for anatomy:
	train_img_path['ELBOW']= train_img_path.Img_Path.apply(lambda x:1 if 'ELBOW' in x else 0)
	train_img_path['SHOULDER']= train_img_path.Img_Path.apply(lambda x:1 if 'SHOULDER' in x else 0)
	train_img_path['FINGER']= train_img_path.Img_Path.apply(lambda x:1 if 'FINGER' in x else 0)
	train_img_path['FOREARM']= train_img_path.Img_Path.apply(lambda x:1 if 'FOREARM' in x else 0)
	train_img_path['HAND']= train_img_path.Img_Path.apply(lambda x:1 if 'HAND' in x else 0)
	train_img_path['HUMERUS']= train_img_path.Img_Path.apply(lambda x:1 if 'HUMERUS' in x else 0)
	train_img_path['WRIST']= train_img_path.Img_Path.apply(lambda x:1 if 'WRIST' in x else 0)

	valid_img_path['ELBOW']= valid_img_path.Img_Path.apply(lambda x:1 if 'ELBOW' in x else 0)
	valid_img_path['SHOULDER']= valid_img_path.Img_Path.apply(lambda x:1 if 'SHOULDER' in x else 0)
	valid_img_path['FINGER']= valid_img_path.Img_Path.apply(lambda x:1 if 'FINGER' in x else 0)
	valid_img_path['FOREARM']= valid_img_path.Img_Path.apply(lambda x:1 if 'FOREARM' in x else 0)
	valid_img_path['HAND']= valid_img_path.Img_Path.apply(lambda x:1 if 'HAND' in x else 0)
	valid_img_path['HUMERUS']= valid_img_path.Img_Path.apply(lambda x:1 if 'HUMERUS' in x else 0)
	valid_img_path['WRIST']= valid_img_path.Img_Path.apply(lambda x:1 if 'WRIST' in x else 0)

	class Mura(Dataset):
		# Dataset over the MURA dataframe: yields
		# (image, one-hot anomaly label, anatomy one-hot vector).
		def __init__(self,df,root,phase, transform=None):
			# NOTE(review): `phase` is accepted but never stored or used,
			# and `root` is stored but unused in __getitem__ — confirm.
			self.df=df

			self.root=root
			self.transform=transform
			
		def __len__(self):
			return len(self.df)
		
		def __getitem__(self,idx):
			img_name=self.df.iloc[idx,0]
			#img=Image.open(img_name, mode = 'r')
			img = cv2.imread(img_name)
			
			# One-hot encode the binary anomaly label (column 1).
			label_anomaly = torch.zeros(2, dtype = int)
			label_anomaly[self.df.iloc[idx,1].item()] = 1
			# Columns 2.. hold the anatomy one-hot flags built above.
			label_anatomy = np.asarray(self.df.iloc[idx,2:], dtype = np.int16)
			
			if self.transform is not None:
				img=self.transform(img)
			
			# Rescale the image towards the ImageNet per-channel mean/std.
			# https://stats.stackexchange.com/questions/46429/transform-data-to-desired-mean-and-standard-deviation
			# NOTE(review): mean/sd are computed from channel 0 only and then
			# reused for all three channels — confirm this is intended.
			mean = torch.mean(img[0,:,:])
			sd = torch.std(img[0,:,:])
			img[0,:,:] = 0.485 + (img[0,:,:] - mean)*(0.229/sd)
			img[1,:,:] = 0.456 + (img[1,:,:] - mean)*(0.224/sd)
			img[2,:,:] = 0.406 + (img[2,:,:] - mean)*(0.225/sd)
			return img,label_anomaly, label_anatomy
		
	# Data augmentation and normalization for training
	# Just normalization for validation
	data_transforms = {
		'train': transforms.Compose([
			transforms.ToPILImage(),
			#transforms.RandomResizedCrop(224),
			#transforms.RandomHorizontalFlip(),
			transforms.Resize((224,224)),
			transforms.ToTensor()
			#transforms.Normalize([0.236, 0.236, 0.236], [0.109, 0.109, 0.109])
		]),
		'val': transforms.Compose([
			transforms.ToPILImage(),
			transforms.Resize((224,224)),
			#transforms.CenterCrop(224),
			transforms.ToTensor()
			#transforms.Normalize([0.236, 0.236, 0.236], [0.109, 0.109, 0.109])
		]),
	}

	#Dataloader:
	train_mura_dataset=Mura(df=train_img_path,root='./',phase = 'train', transform=data_transforms['train'])
	val_mura_dataset=Mura(df=valid_img_path,root='./',phase = 'val',transform=data_transforms['val'])
	train_loader=DataLoader(dataset=train_mura_dataset,batch_size=64,num_workers=2, shuffle = True )
	val_loader=DataLoader(dataset=val_mura_dataset,batch_size=64,num_workers=2, shuffle = True )

	dataloaders={
		'train':train_loader,
		'val':val_loader
	}
	dataset_sizes={
		'train':len(train_mura_dataset),
		'val':len(val_mura_dataset)
	}

	# Per-epoch history of mean loss and mean KL, train and validation.
	train_loss, valid_loss = [], []
	train_kl, valid_kl = [], []

	#-------------------------------------------------------------------------------
	for epoch in range(num_epochs):
		print(epoch)
		batch_loss, batch_kl = [], []
		batch_loss_val, batch_kl_val = [], []
		model.train()
		
		# Go through each batch in the training dataset using the loader
		# Note that y is not necessarily known as it is here
		for x, y, _ in dataloaders['train']:
			# Keep only channel 0 as a single-channel input (B, 1, H, W).
			x = Variable(x[:,0,:,:].unsqueeze(1))
			
			# This is an alternative way of putting
			# a tensor on the GPU
			x = x.cuda()
			
			outputs = model(x)
			x_hat = outputs['x_hat']
			mu, log_var = outputs['mu'], outputs['log_var']

			elbo, kl = loss_function(x_hat, x, mu, log_var)
			
			optimizer.zero_grad()
			elbo.backward()
			optimizer.step()
			
			batch_loss.append(elbo.item())
			batch_kl.append(kl.item())

		train_loss.append(np.mean(batch_loss))
		train_kl.append(np.mean(batch_kl))

		# Evaluate, do not propagate gradients
		with torch.no_grad():
			model.eval()
			for x, y, _ in dataloaders['val']:
				x = Variable(x[:,0,:,:].unsqueeze(1))
			
				x = x.cuda()

				outputs = model(x)
				x_hat = outputs['x_hat']
				mu, log_var = outputs['mu'], outputs['log_var']
				#z = outputs["z"]

				elbo, kl = loss_function(x_hat, x, mu, log_var)

				batch_loss_val.append(elbo.item())
				batch_kl_val.append(kl.item())
			
			# We save the latent variable and reconstruction for later use
			# we will need them on the CPU to plot
			#x = x.to("cpu")
			#x_hat = x_hat.to("cpu")
			#z = z.detach().to("cpu").numpy()

			valid_loss.append(np.mean(batch_loss_val))
			valid_kl.append(np.mean(batch_kl_val))

			# NOTE(review): despite the keyword name, this reports the whole
			# validation-loss history list, not a scalar accuracy — confirm
			# this is what tune expects.
			tune.track.log(mean_acc =valid_loss)
Пример #14
0
    with open(os.path.join(data_path, 'test_data.json'), 'r') as f:
        test_data = json.load(f)

    # Training loader shuffles full batches; test loader yields one
    # utterance at a time in fixed order.
    train_loader = torch.utils.data.DataLoader(
        AudiobookDataset(train_data),
        collate_fn=train_collate,
        batch_size=args.batch_size, shuffle=True, **kwargs)

    test_loader = torch.utils.data.DataLoader(
        AudiobookDataset(test_data),
        collate_fn=test_collate,
        batch_size=1, shuffle=False, **kwargs)

    # AutoVC-style generator with the hyperparameters from the hp module.
    model = Generator(hp.dim_neck, hp.dim_emb, hp.dim_pre, hp.freq).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Optionally resume: load_checkpoint restores model/optimizer state and
    # returns the epoch the checkpoint was written at.
    current_epoch = 0
    if args.checkpoint:
        current_epoch = load_checkpoint(args.checkpoint, model, device, optimizer)
    
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(current_epoch + 1, args.epochs + 1):
        print(f'epoch {epoch}')
        train(args, model, device, train_loader, optimizer, epoch)

        # Evaluate and checkpoint every 10th epoch.
        if epoch % 10 == 0:
            test(model, device, test_loader, checkpoint_dir, epoch)
            save_checkpoint(device, model, optimizer, checkpoint_dir, epoch)
Пример #15
0
    def __init__(self, args, embedding, vocab):
        """Build the coupled VAE-GAN trainer.

        Sets up the coupled VAE generator, the image and text
        discriminators, their two Adam optimizers, and the loss/metric
        meters used for logging.

        Args:
            args: parsed experiment namespace; boolean switches arrive as
                the strings "true"/"false" from the CLI.
            embedding: word-embedding module shared by the text networks.
            vocab: vocabulary; only len(vocab) is used for output sizing.
        """
        super(coupled_vae_gan_trainer,self).__init__(args, embedding, vocab)

        # CLI flags are string-valued; compare against "true" explicitly.
        self.use_variational = args.variational == "true"
        self.freeze_encoders = args.freeze_encoders == "true"
        self.use_feature_matching = args.use_feature_matching == "true"
        self.use_gradient_penalty = args.use_gradient_penalty == "true"
        self.double_fake_error = args.use_double_fake_error == "true"
        self.use_gumbel_generator = args.use_gumbel_generator == "true"
        self.cycle_consistency = args.cycle_consistency == "true"
        self.use_parallel = args.use_parallel_recon == "true"

        # Cycle-consistency criterion.
        # FIX: `size_average=True` is deprecated in modern PyTorch;
        # reduction='mean' is its exact equivalent.
        if args.cycle_consistency_criterion == "mse":
            self.cc_criterion = nn.MSELoss(reduction='mean')
        else:
            self.cc_criterion = nn.HingeEmbeddingLoss()


        weight_sharing = args.weight_sharing == "true"
        # Setting up the networks
        self.networks["coupled_vae"] = CoupledVAE(embedding,
                                                  len(vocab),
                                                  hidden_size=args.hidden_size,
                                                  latent_size=args.latent_size,
                                                  batch_size = args.batch_size,
                                                  img_dimension= args.crop_size,
                                                  mask= args.common_emb_ratio,
                                                  use_variational= self.use_variational,
                                                  weight_sharing= weight_sharing,
                                                  use_gumbel_generator = self.use_gumbel_generator)

        self.networks["img_discriminator"] = ImgDiscriminator(args.crop_size, batch_size = args.batch_size,
                                                              mask = args.common_emb_ratio,
                                                              latent_size= args.latent_size)

        self.networks["txt_discriminator"] = SeqDiscriminator(embedding, len(vocab),
                                                              hidden_size=args.hidden_size,
                                                              mask =args.common_emb_ratio,
                                                              latent_size= args.latent_size,
                                                              use_gumbel_generator = self.use_gumbel_generator)

        # Initialiaze the weights
        self.networks["img_discriminator"].apply(weights_init)
        self.networks["coupled_vae"].apply(weights_init)
        # Setting up the optimizers: one for the generator (VAE), one shared
        # by both discriminators; betas=(0.5, 0.999) is the usual GAN choice.
        self.optimizers["coupled_vae"] = optim.Adam(valid_params(self.networks["coupled_vae"].parameters()),\
                                                    lr=args.learning_rate, betas=(0.5, 0.999)) #, weight_decay=0.00001)

        self.optimizers["discriminators"] = optim.Adam(chain(self.networks["img_discriminator"].parameters(), \
                                                       valid_params(self.networks["txt_discriminator"].parameters()), ), \
                                                       lr=args.learning_rate, betas=(0.5, 0.999))  # , weight_decay=0.00001)
        # Portion of the latent vector shared between the two modalities.
        self.masked_latent_size = int(args.common_emb_ratio * args.latent_size)

        # Setting up the losses
        losses = ["Ls_D_img", "Ls_D_txt",
                                  "Ls_G_img",  "Ls_G_txt",
                                  "Ls_D_img_rl", "Ls_D_txt_rl",
                                  "Ls_D_img_fk","Ls_D_txt_fk",
                                  "Ls_RC_img", "Ls_RC_txt", "Ls_RC_txt_fk",
                                  # "Ls_KL_img", "Ls_KL_txt",
                                  "Ls_VAE"]



        # Optional loss terms, enabled by the corresponding flags above.
        if self.use_gradient_penalty:
            losses += ["Ls_GP_img", "Ls_GP_txt"]

        if self.use_feature_matching:
            losses.append("Ls_FM_img")

        if self.double_fake_error:
            losses.append("Ls_D_img_wr")
            losses.append("Ls_D_txt_wr")

        if self.cycle_consistency:
            losses.append("Ls_CC_txt")
            losses.append("Ls_CC_img")

        self.create_losses_meter(losses)

        # Set Evaluation Metrics
        metrics = ["Ls_RC_txt",
                                  "Ls_D_img_rl", "Ls_D_txt_rl",
                                  "Ls_D_img_fk","Ls_D_txt_fk",
                                  "BLEU", "TxtL2"]

        if self.double_fake_error:
            metrics.append("Ls_D_img_wr")
            metrics.append("Ls_D_txt_wr")

        self.create_metrics_meter(metrics)
        # Constant +1/-1 tensors used as WGAN-style backward() arguments.
        self.one = torch.FloatTensor([1])
        self.mone = self.one * -1
        if args.cuda:
            self.one = self.one.cuda()
            self.mone = self.mone.cuda()

        self.recon_criterion = nn.MSELoss()

        self.nets_to_cuda()

        self.train_swapped = True # It's a GAN!!
        self.step = 0

        # Optionally warm-start the VAE from a pretrained checkpoint.
        if args.load_vae != "NONE":
            self.load_vae(args.load_vae)

        self.epoch_max_len = Variable(torch.zeros(self.args.batch_size).long().cuda())
Пример #16
0
# Row offsets of the first occurrence of each class label in the test batch.
# NOTE(review): assumes test_batch_y is sorted by class (0,1,2,3) — confirm.
r0 = int(list(test_batch_y).index(1))
r1 = int(list(test_batch_y).index(2))
r2 = int(list(test_batch_y).index(3))
r3 = int(len(list(test_batch_y)))
r = [0, r0, r1, r2, r3]

# Stack the per-class masks in the same class order as the offsets above.
test_batch_mask = np.concatenate([
    test_batch_mask["ribosome"],
    test_batch_mask["proteasome_s"],
    test_batch_mask["TRiC"],
    test_batch_mask["membrane"],
])

# Discriminator, its optimizer and binary cross-entropy criterion (GPU).
net_disc = Disc().cuda()
net_disc.train()
optimizer_disc = optim.Adam(net_disc.parameters(), lr=1e-5, betas=(0.9, 0.999))
criterion_disc = nn.BCELoss().cuda()

minval = 0
maxval = 0
min_l = 1e10  # running best (lowest) loss seen so far

epochs = 5

# Batch size from argv[1] when supplied; otherwise default to 32 so the
# total number of samples per epoch (60 * 32) stays constant.
# FIX: the original bare `except:` also swallowed SystemExit and
# KeyboardInterrupt; catch only the errors argv parsing can raise.
try:
    n_batch = int(sys.argv[1])
    num_iterations = 60 * 32 // n_batch
except (IndexError, ValueError):
    n_batch = 32
    num_iterations = 60
Пример #17
0
                                          suffix_str='base0')
    logger = Logger(home_path)

    # Best-model checkpoint path; a companion "<path>.done" marker file
    # signals that training already completed in a previous run.
    hbest_path = os.path.join(home_path, 'model.best.pth')

    if not os.path.isdir(home_path):
        os.makedirs(home_path)
    best_gen_loss = 9999
    if not os.path.isfile(hbest_path + ".done"):
        print(colored('Training from scratch', 'green'))
        best_loss = -1

        # Generator-side optimizer: decoder (GenX) and encoder (GenZ)
        # parameter groups share one Adam instance.
        optimizerG = optim.Adam([{
            'params': model.GenX.parameters()
        }, {
            'params': model.GenZ.parameters()
        }],
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))

        # Discriminator-side optimizer over all three discriminators
        # (latent, image, and joint (x, z)).
        optimizerD = optim.Adam([{
            'params': model.DisZ.parameters()
        }, {
            'params': model.DisX.parameters()
        }, {
            'params': model.DisXZ.parameters()
        }],
                                lr=args.lr,
                                betas=(args.beta1, args.beta2))
        criterion = nn.BCELoss()
        for epoch in range(1, 100 + 1):
Пример #18
0
    return [resize_convert(x) for x in args]


# UNet correction of Mask-RCNN outputs: dataset locations and run name.
model_name = 'UNet_MaskRCNN_correction'
results_root_path = '/media/tdanka/B8703F33703EF828/tdanka/results'
train_dataset_loc = '/media/tdanka/B8703F33703EF828/tdanka/UNet_correction/TIVADAR/rcnn2unet_split/train/loc.csv'
validate_dataset_loc = '/media/tdanka/B8703F33703EF828/tdanka/UNet_correction/TIVADAR/rcnn2unet_split/test/loc.csv'
# Per-class cross-entropy weights; rare classes are heavily upweighted.
class_weights = [1.0, 5.0, 25.0, 25.0]

# Shared transform (128x128 crops, random horizontal flip) for both splits.
tf = make_transform_RCNN(size=(128, 128), p_flip=0.5, long_mask=True)
train_dataset = TrainWithRCNNMask(train_dataset_loc, transform=tf)
validate_dataset = TrainWithRCNNMask(validate_dataset_loc, transform=tf)
test_dataset = TrainFromFolder(train_dataset_loc, transform=T.ToTensor(), remove_alpha=True)

# Resume from the previously saved network rather than a fresh UNet.
net = torch.load('/media/tdanka/B8703F33703EF828/tdanka/results/UNet_MaskRCNN_correction/UNet_MaskRCNN_correction')#UNet(4, 4, softmax=True)
loss = CrossEntropyLoss2d(weight=torch.from_numpy(np.array(class_weights)).float())
n_epochs = 20
# LR drops at 30%, 70% and 90% of the training run.
lr_milestones = [int(p*n_epochs) for p in [0.3, 0.7, 0.9]]
optimizer = optim.Adam(net.parameters(), lr=1e-2)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, lr_milestones)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, factor=0.1, verbose=True)

model = RCNNCorrection(
    model=net, loss=loss, optimizer=optimizer, scheduler=scheduler,
    model_name=model_name, results_root_path=results_root_path
)

# Only visualisation is active here; the training call is commented out.
#model.train_model(train_dataset, n_epochs=n_epochs, n_batch=16, verbose=False, validation_dataset=validate_dataset)
model.visualize(train_dataset, n_inst=None)

Пример #19
0


# Step 6: define the optimizer
net = NET().to(DEVICE)

# torch.save(net, "net.pkl")                       # save the entire net
# torch.save(net.state_dict(), "net_params.pkl")   # save parameters only
# def restore_net():
#     net_2 = torch.load("net.pkl")                # load the entire net
# def restore_params():
#     net_3 = <a net built with the same architecture>
#     net_3.load_state_dict(torch.load("net_params.pkl"))  # load params into the fresh net


optimizer = optim.Adam(net.parameters())
# Step 7: define the training routine
def train_model(net, device, train_loader, optimizer, epoch):
    """Run one training epoch of `net` over `train_loader`.

    Uses cross-entropy loss and logs progress every 3000 batches
    (including the first batch of the epoch).
    """
    # Training mode: BatchNorm uses per-batch statistics and Dropout is
    # active; eval() would freeze both for inference.
    net.train()
    for step, batch in enumerate(train_loader):

        features, targets = batch
        features = features.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()                       # clear previous gradients
        predictions = net(features)
        # cross_entropy combines log-softmax with NLL over class logits.
        loss = F.cross_entropy(predictions, targets)
        loss.backward()                             # backpropagate
        optimizer.step()                            # update parameters
        # Periodic progress report.
        if step % 3000 == 0:
            print("Train Epoch: {} \t Loss: {:.6f}".format(epoch, loss.item()))
Пример #20
0
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        # Dense features -> final ReLU -> 3x3 average pooling, then flatten
        # per sample and classify with the linear head.
        feature_maps = self.features(x)
        activated = F.relu(feature_maps, inplace=True)
        pooled = F.avg_pool2d(activated, kernel_size=3, stride=1)
        flattened = pooled.view(feature_maps.size(0), -1)
        return self.classifier(flattened)



# DenseNet classifier with Adam (L2 weight decay) and GPU cross-entropy.
model = DenseNet()
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss().cuda()

# Wrap in DataParallel when any GPU is present; cudnn.benchmark picks the
# fastest convolution algorithms for fixed input sizes.
if torch.cuda.device_count() > 0:
	print("USE", torch.cuda.device_count(), "GPUs!")
	model = nn.DataParallel(model)
	cudnn.benchmark = True
else:
	print("USE ONLY CPU!")

if torch.cuda.is_available():
	model.cuda()


def train(epoch):
	model.train()
Пример #21
0
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed every RNG in play for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

# SoAGREE model with 128-dimensional embeddings; optionally warm-start
# from a pretrained checkpoint.
model = SoAGREE(128, 128, args.dropout)
if args.pretrained:
    model.load_state_dict(torch.load(args.pretrained))
if args.cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(),
                       lr=args.lr,
                       weight_decay=args.weight_decay)
# Shrink the LR when the monitored metric plateaus for 2 epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
loss = nn.BCELoss()
dataset = GATDataset(args)
train_loader, val_loader = dataset.get_data_loader(torch.device('cpu'),
                                                   ['train', 'val'])
# T = number of training repos; validation repos occupy indices [T, N).
T = dataset.T
N = T + len(val_loader)
# Team membership per validation repo, merging the dataset's two halves.
repo_teams = [
    dataset.repo_teams_zero[r] + dataset.repo_teams_one[r]
    for r in range(T, N)
]

# The dataset object is no longer needed; release its memory.
del dataset
Пример #22
0
def train(hyp):
    epochs = opt.epochs  # 300
    batch_size = opt.batch_size  # 64
    weights = opt.weights  # initial training weights

    # Configure
    init_seeds(1)
    with open(opt.data) as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)  # model dict
    train_path = data_dict['train']
    test_path = data_dict['val']
    nc = 1 if opt.single_cls else int(data_dict['nc'])  # number of classes

    # Remove previous results
    for f in glob.glob('*_batch*.jpg') + glob.glob(results_file):
        os.remove(f)

    # Create model
    model = Model(opt.cfg).to(device)
    assert model.md['nc'] == nc, '%s nc=%g classes but %s nc=%g classes' % (
        opt.data, nc, opt.cfg, model.md['nc'])
    model.names = data_dict['names']

    # Image sizes
    gs = int(max(model.stride))  # grid size (max stride)
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size
                         ]  # verify imgsz are gs-multiples

    # Optimizer
    nbs = 64  # nominal batch size
    accumulate = max(round(nbs / batch_size),
                     1)  # accumulate loss before optimizing
    hyp['weight_decay'] *= batch_size * accumulate / nbs  # scale weight_decay
    pg0, pg1, pg2 = [], [], []  # optimizer parameter groups
    for k, v in model.named_parameters():
        if v.requires_grad:
            if '.bias' in k:
                pg2.append(v)  # biases
            elif '.weight' in k and '.bn' not in k:
                pg1.append(v)  # apply weight decay
            else:
                pg0.append(v)  # all else

    optimizer = optim.Adam(pg0, lr=hyp['lr0']) if opt.adam else \
        optim.SGD(pg0, lr=hyp['lr0'], momentum=hyp['momentum'], nesterov=True)
    optimizer.add_param_group({
        'params': pg1,
        'weight_decay': hyp['weight_decay']
    })  # add pg1 with weight_decay
    optimizer.add_param_group({'params': pg2})  # add pg2 (biases)
    print('Optimizer groups: %g .bias, %g conv.weight, %g other' %
          (len(pg2), len(pg1), len(pg0)))
    del pg0, pg1, pg2

    # Load Model
    google_utils.attempt_download(weights)
    start_epoch, best_fitness = 0, 0.0
    if weights.endswith('.pt'):  # pytorch format
        ckpt = torch.load(weights, map_location=device)  # load checkpoint

        # load model
        try:
            ckpt['model'] = {
                k: v
                for k, v in ckpt['model'].float().state_dict().items()
                if model.state_dict()[k].shape == v.shape
            }  # to FP32, filter
            model.load_state_dict(ckpt['model'], strict=False)
        except KeyError as e:
            s = "%s is not compatible with %s. Specify --weights '' or specify a --cfg compatible with %s." \
                % (opt.weights, opt.cfg, opt.weights)
            raise KeyError(s) from e

        # load optimizer
        if ckpt['optimizer'] is not None:
            optimizer.load_state_dict(ckpt['optimizer'])
            best_fitness = ckpt['best_fitness']

        # load results
        if ckpt.get('training_results') is not None:
            with open(results_file, 'w') as file:
                file.write(ckpt['training_results'])  # write results.txt

        start_epoch = ckpt['epoch'] + 1
        del ckpt

    # Mixed precision training https://github.com/NVIDIA/apex
    if mixed_precision:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level='O1',
                                          verbosity=0)

    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((
        (1 + math.cos(x * math.pi / epochs)) / 2)**1.0) * 0.9 + 0.1  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    scheduler.last_epoch = start_epoch - 1  # do not move
    # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
    # plot_lr_scheduler(optimizer, scheduler, epochs)

    # Initialize distributed training
    if device.type != 'cpu' and torch.cuda.device_count(
    ) > 1 and torch.distributed.is_available():
        dist.init_process_group(
            backend='nccl',  # distributed backend
            init_method='tcp://127.0.0.1:9999',  # init method
            world_size=1,  # number of nodes
            rank=0)  # node rank
        model = torch.nn.parallel.DistributedDataParallel(model)
        # pip install torch==1.4.0+cu100 torchvision==0.5.0+cu100 -f https://download.pytorch.org/whl/torch_stable.html

    # Dataset
    dataset = LoadImagesAndLabels(
        train_path,
        imgsz,
        batch_size,
        augment=True,
        hyp=hyp,  # augmentation hyperparameters
        rect=opt.rect,  # rectangular training
        cache_images=opt.cache_images,
        single_cls=opt.single_cls,
        stride=gs)
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()  # max label class
    assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Correct your labels or your model.' % (
        mlc, nc, opt.cfg)

    # Dataloader
    batch_size = min(batch_size, len(dataset))
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0,
              8])  # number of workers
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=nw,
        shuffle=not opt.
        rect,  # Shuffle=True unless rectangular training is used
        pin_memory=True,
        collate_fn=dataset.collate_fn)

    # Testloader
    testloader = torch.utils.data.DataLoader(LoadImagesAndLabels(
        test_path,
        imgsz_test,
        batch_size,
        hyp=hyp,
        rect=True,
        cache_images=opt.cache_images,
        single_cls=opt.single_cls,
        stride=gs),
                                             batch_size=batch_size,
                                             num_workers=nw,
                                             pin_memory=True,
                                             collate_fn=dataset.collate_fn)

    # Model parameters
    hyp['cls'] *= nc / 80.  # scale coco-tuned hyp['cls'] to current dataset
    model.nc = nc  # attach number of classes to model
    model.hyp = hyp  # attach hyperparameters to model
    model.gr = 1.0  # giou loss ratio (obj_loss = 1.0 or giou)
    model.class_weights = labels_to_class_weights(dataset.labels, nc).to(
        device)  # attach class weights

    # Class frequency
    labels = np.concatenate(dataset.labels, 0)
    c = torch.tensor(labels[:, 0])  # classes
    # cf = torch.bincount(c.long(), minlength=nc) + 1.
    # model._initialize_biases(cf.to(device))
    if tb_writer:
        plot_labels(labels)
        tb_writer.add_histogram('classes', c, 0)

    # Check anchors
    if not opt.noautoanchor:
        check_anchors(dataset, model=model, thr=hyp['anchor_t'], imgsz=imgsz)

    # Exponential moving average
    ema = torch_utils.ModelEMA(model)

    # Start training
    t0 = time.time()
    nb = len(dataloader)  # number of batches
    n_burn = max(3 * nb,
                 1e3)  # burn-in iterations, max(3 epochs, 1k iterations)
    maps = np.zeros(nc)  # mAP per class
    results = (
        0, 0, 0, 0, 0, 0, 0
    )  # 'P', 'R', 'mAP', 'F1', 'val GIoU', 'val Objectness', 'val Classification'
    print('Image sizes %g train, %g test' % (imgsz, imgsz_test))
    print('Using %g dataloader workers' % nw)
    print('Starting training for %g epochs...' % epochs)
    # torch.autograd.set_detect_anomaly(True)
    for epoch in range(
            start_epoch, epochs
    ):  # epoch ------------------------------------------------------------------
        model.train()

        # Update image weights (optional)
        if dataset.image_weights:
            w = model.class_weights.cpu().numpy() * (1 -
                                                     maps)**2  # class weights
            image_weights = labels_to_image_weights(dataset.labels,
                                                    nc=nc,
                                                    class_weights=w)
            dataset.indices = random.choices(range(dataset.n),
                                             weights=image_weights,
                                             k=dataset.n)  # rand weighted idx

        mloss = torch.zeros(4, device=device)  # mean losses
        print(('\n' + '%10s' * 8) % ('Epoch', 'gpu_mem', 'GIoU', 'obj', 'cls',
                                     'total', 'targets', 'img_size'))
        pbar = tqdm(enumerate(dataloader), total=nb)  # progress bar
        for i, (
                imgs, targets, paths, _
        ) in pbar:  # batch -------------------------------------------------------------
            ni = i + nb * epoch  # number integrated batches (since train start)
            imgs = imgs.to(device).float(
            ) / 255.0  # uint8 to float32, 0 - 255 to 0.0 - 1.0

            # Burn-in
            if ni <= n_burn:
                xi = [0, n_burn]  # x interp
                # model.gr = np.interp(ni, xi, [0.0, 1.0])  # giou loss ratio (obj_loss = 1.0 or giou)
                accumulate = max(
                    1,
                    np.interp(ni, xi, [1, nbs / batch_size]).round())
                for j, x in enumerate(optimizer.param_groups):
                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                    x['lr'] = np.interp(
                        ni, xi,
                        [0.1 if j == 2 else 0.0, x['initial_lr'] * lf(epoch)])
                    if 'momentum' in x:
                        x['momentum'] = np.interp(ni, xi,
                                                  [0.9, hyp['momentum']])

            # Multi-scale
            if opt.multi_scale:
                sz = random.randrange(imgsz * 0.5,
                                      imgsz * 1.5 + gs) // gs * gs  # size
                sf = sz / max(imgs.shape[2:])  # scale factor
                if sf != 1:
                    ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]
                          ]  # new shape (stretched to gs-multiple)
                    imgs = F.interpolate(imgs,
                                         size=ns,
                                         mode='bilinear',
                                         align_corners=False)

            # Forward
            pred = model(imgs)

            # Loss
            loss, loss_items = compute_loss(pred, targets.to(device), model)
            if not torch.isfinite(loss):
                print('WARNING: non-finite loss, ending training ', loss_items)
                return results

            # Backward
            if mixed_precision:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # Optimize
            if ni % accumulate == 0:
                optimizer.step()
                optimizer.zero_grad()
                ema.update(model)

            # Print
            mloss = (mloss * i + loss_items) / (i + 1)  # update mean losses
            mem = '%.3gG' % (torch.cuda.memory_cached() /
                             1E9 if torch.cuda.is_available() else 0)  # (GB)
            s = ('%10s' * 2 + '%10.4g' * 6) % ('%g/%g' % (epoch, epochs - 1),
                                               mem, *mloss, targets.shape[0],
                                               imgs.shape[-1])
            pbar.set_description(s)

            # Plot
            if ni < 3:
                f = 'train_batch%g.jpg' % ni  # filename
                result = plot_images(images=imgs,
                                     targets=targets,
                                     paths=paths,
                                     fname=f)
                if tb_writer and result is not None:
                    tb_writer.add_image(f,
                                        result,
                                        dataformats='HWC',
                                        global_step=epoch)
                    # tb_writer.add_graph(model, imgs)  # add model to tensorboard

            # end batch ------------------------------------------------------------------------------------------------

        # Scheduler
        scheduler.step()

        # mAP
        ema.update_attr(model)
        final_epoch = epoch + 1 == epochs
        if not opt.notest or final_epoch:  # Calculate mAP
            results, maps, times = test.test(
                opt.data,
                batch_size=batch_size,
                imgsz=imgsz_test,
                save_json=final_epoch
                and opt.data.endswith(os.sep + 'coco.yaml'),
                model=ema.ema,
                single_cls=opt.single_cls,
                dataloader=testloader)

        # Write
        with open(results_file, 'a') as f:
            f.write(s + '%10.4g' * 7 % results +
                    '\n')  # P, R, mAP, F1, test_losses=(GIoU, obj, cls)
        if len(opt.name) and opt.bucket:
            os.system('gsutil cp results.txt gs://%s/results/results%s.txt' %
                      (opt.bucket, opt.name))

        # Tensorboard
        if tb_writer:
            tags = [
                'train/giou_loss', 'train/obj_loss', 'train/cls_loss',
                'metrics/precision', 'metrics/recall', 'metrics/mAP_0.5',
                'metrics/F1', 'val/giou_loss', 'val/obj_loss', 'val/cls_loss'
            ]
            for x, tag in zip(list(mloss[:-1]) + list(results), tags):
                tb_writer.add_scalar(tag, x, epoch)

        # Update best mAP
        fi = fitness(np.array(results).reshape(
            1, -1))  # fitness_i = weighted combination of [P, R, mAP, F1]
        if fi > best_fitness:
            best_fitness = fi

        # Save model
        save = (not opt.nosave) or (final_epoch and not opt.evolve)
        if save:
            with open(results_file, 'r') as f:  # create checkpoint
                ckpt = {
                    'epoch': epoch,
                    'best_fitness': best_fitness,
                    'training_results': f.read(),
                    'model':
                    ema.ema.module if hasattr(model, 'module') else ema.ema,
                    'optimizer':
                    None if final_epoch else optimizer.state_dict()
                }

            # Save last, best and delete
            torch.save(ckpt, last)
            if (best_fitness == fi) and not final_epoch:
                torch.save(ckpt, best)
            del ckpt

        # end epoch ----------------------------------------------------------------------------------------------------
    # end training

    n = opt.name
    if len(n):
        n = '_' + n if not n.isnumeric() else n
        fresults, flast, fbest = 'results%s.txt' % n, wdir + 'last%s.pt' % n, wdir + 'best%s.pt' % n
        for f1, f2 in zip([wdir + 'last.pt', wdir + 'best.pt', 'results.txt'],
                          [flast, fbest, fresults]):
            if os.path.exists(f1):
                os.rename(f1, f2)  # rename
                ispt = f2.endswith('.pt')  # is *.pt
                strip_optimizer(f2) if ispt else None  # strip optimizer
                os.system('gsutil cp %s gs://%s/weights' % (
                    f2, opt.bucket)) if opt.bucket and ispt else None  # upload

    if not opt.evolve:
        plot_results()  # save as results.png
    print('%g epochs completed in %.3f hours.\n' % (epoch - start_epoch + 1,
                                                    (time.time() - t0) / 3600))
    dist.destroy_process_group(
    ) if device.type != 'cpu' and torch.cuda.device_count() > 1 else None
    torch.cuda.empty_cache()
    return results
Пример #23
0
def main(args, logger):
    """Train a BERT-based multi-label classifier with 5-fold GroupKFold CV.

    For every fold this builds train/validation datasets and loaders, trains
    for MAX_EPOCH epochs (the BERT body is frozen during the first epoch
    only), saves a checkpoint after each epoch, and records the fold's best
    validation metric.  After all folds it logs mean +- std of the best
    metrics and sends notifications.

    Training can resume from ``args.checkpoint``: folds/epochs already
    finished are skipped and their histories restored.

    Relies on module-level configuration constants (MNT_DIR, EXP_ID,
    BATCH_SIZE, MAX_EPOCH, tokenizer/model settings) and on helpers defined
    elsewhere in this file (load_checkpoint, save_checkpoint, sel_log,
    train_one_epoch, test, save_and_clean_for_prediction, ...).
    """
    trn_df = pd.read_pickle(f'{MNT_DIR}/inputs/nes_info/trn_df.pkl')
    trn_df['is_original'] = 1  # flags rows that are original (non-pseudo) data

    # Group folds by the label-encoded question body so duplicated questions
    # never end up on both sides of a train/validation split.
    gkf = GroupKFold(n_splits=5).split(
        X=trn_df.question_body,
        groups=trn_df.question_body_le,
    )

    # Per-fold training curves; each entry maps fold -> list over epochs.
    histories = {
        'trn_loss': {},
        'val_loss': {},
        'val_metric': {},
        'val_metric_raws': {},
    }
    loaded_fold = -1
    loaded_epoch = -1
    if args.checkpoint:
        histories, loaded_fold, loaded_epoch = load_checkpoint(args.checkpoint)

    fold_best_metrics = []
    fold_best_metrics_raws = []
    for fold, (trn_idx, val_idx) in enumerate(gkf):
        if fold < loaded_fold:
            # Fold already finished in a previous run: restore its best
            # metric from the loaded histories and move on.
            fold_best_metrics.append(np.max(histories["val_metric"][fold]))
            fold_best_metrics_raws.append(
                histories["val_metric_raws"][fold][np.argmax(
                    histories["val_metric"][fold])])
            continue
        sel_log(
            f' --------------------------- start fold {fold} --------------------------- ',
            logger)
        fold_trn_df = trn_df.iloc[trn_idx]
        fold_trn_df = fold_trn_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        # validation uses only original (non-pseudo) rows
        fold_val_df = trn_df.iloc[val_idx].query('is_original == 1')
        fold_val_df = fold_val_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        if args.debug:
            fold_trn_df = fold_trn_df.sample(100, random_state=71)
            fold_val_df = fold_val_df.sample(100, random_state=71)
        # NOTE: an earlier revision built `tokens` from a word-frequency count
        # over the fold's texts, but that result was unconditionally
        # overwritten by the category markers below, so the costly dead
        # computation has been removed.  Markers are lower-cased to match the
        # tokenizer's casing.
        tokens = [
            'CAT_TECHNOLOGY'.casefold(),
            'CAT_STACKOVERFLOW'.casefold(),
            'CAT_CULTURE'.casefold(),
            'CAT_SCIENCE'.casefold(),
            'CAT_LIFE_ARTS'.casefold(),
        ]

        trn_dataset = QUESTDataset(
            df=fold_trn_df,
            mode='train',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=DO_LOWER_CASE,
            LABEL_COL=LABEL_COL,
            t_max_len=T_MAX_LEN,
            q_max_len=Q_MAX_LEN,
            a_max_len=A_MAX_LEN,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        # worker_init_fn reseeds numpy per worker so augmentation streams
        # are not identical across dataloader workers
        trn_sampler = RandomSampler(data_source=trn_dataset)
        trn_loader = DataLoader(trn_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=trn_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=True,
                                pin_memory=True)
        val_dataset = QUESTDataset(
            df=fold_val_df,
            mode='valid',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=DO_LOWER_CASE,
            LABEL_COL=LABEL_COL,
            t_max_len=T_MAX_LEN,
            q_max_len=Q_MAX_LEN,
            a_max_len=A_MAX_LEN,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        val_sampler = RandomSampler(data_source=val_dataset)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=val_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=False,
                                pin_memory=True)

        fobj = BCEWithLogitsLoss()
        state_dict = BertModel.from_pretrained(MODEL_PRETRAIN).state_dict()
        model = BertModelForBinaryMultiLabelClassifier(
            num_labels=len(LABEL_COL),
            config_path=MODEL_CONFIG_PATH,
            state_dict=state_dict,
            token_size=len(trn_dataset.tokenizer),
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            cat_last_layer_num=1,
            do_ratio=0.2,
        )
        optimizer = optim.Adam(model.parameters(), lr=3e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=MAX_EPOCH,
                                                         eta_min=1e-5)

        # restore model/optimizer/scheduler state when resuming this fold
        if args.checkpoint and fold == loaded_fold:
            load_checkpoint(args.checkpoint, model, optimizer, scheduler)

        for epoch in tqdm(list(range(MAX_EPOCH))):
            if fold <= loaded_fold and epoch <= loaded_epoch:
                continue  # epoch already trained in the resumed run
            # warm-up: train only the classification head in epoch 0
            if epoch < 1:
                model.freeze_unfreeze_bert(freeze=True, logger=logger)
            else:
                model.freeze_unfreeze_bert(freeze=False, logger=logger)
            model = DataParallel(model)  # multi-GPU wrap (unwrapped below)
            model = model.to(DEVICE)
            trn_loss = train_one_epoch(model, fobj, optimizer, trn_loader,
                                       DEVICE)
            val_loss, val_metric, val_metric_raws, val_y_preds, val_y_trues, val_qa_ids = test(
                model, fobj, val_loader, DEVICE, mode='valid')

            scheduler.step()
            # append this epoch's numbers to the fold's history
            for key, value in (('trn_loss', trn_loss),
                               ('val_loss', val_loss),
                               ('val_metric', val_metric),
                               ('val_metric_raws', val_metric_raws)):
                histories[key].setdefault(fold, []).append(value)

            logging_val_metric_raws = ''
            for val_metric_raw in val_metric_raws:
                logging_val_metric_raws += f'{float(val_metric_raw):.4f}, '

            sel_log(
                f'fold : {fold} -- epoch : {epoch} -- '
                f'trn_loss : {float(trn_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_loss : {float(val_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_metric : {float(val_metric):.4f} -- '
                f'val_metric_raws : {logging_val_metric_raws}', logger)
            # unwrap DataParallel and move back to CPU before checkpointing
            model = model.to('cpu')
            model = model.module
            save_checkpoint(
                f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                model,
                optimizer,
                scheduler,
                histories,
                val_y_preds,
                val_y_trues,
                val_qa_ids,
                fold,
                epoch,
                val_loss,
                val_metric,
            )
        fold_best_metrics.append(np.max(histories["val_metric"][fold]))
        fold_best_metrics_raws.append(
            histories["val_metric_raws"][fold][np.argmax(
                histories["val_metric"][fold])])
        save_and_clean_for_prediction(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                                      trn_dataset.tokenizer,
                                      clean=False)
        del model

    # calc training stats
    fold_best_metric_mean = np.mean(fold_best_metrics)
    fold_best_metric_std = np.std(fold_best_metrics)
    fold_stats = f'{EXP_ID} : {fold_best_metric_mean:.4f} +- {fold_best_metric_std:.4f}'
    sel_log(fold_stats, logger)
    send_line_notification(fold_stats)

    fold_best_metrics_raws_mean = np.mean(fold_best_metrics_raws, axis=0)
    fold_raw_stats = ''
    for metric_stats_raw in fold_best_metrics_raws_mean:
        fold_raw_stats += f'{float(metric_stats_raw):.4f},'
    sel_log(fold_raw_stats, logger)
    send_line_notification(fold_raw_stats)

    sel_log('now saving best checkpoints...', logger)
Пример #24
0
# Print the model
print(netD)

# Initialize BCELoss function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
# (a fixed batch means the same latent points are decoded every time, so
# saved image grids are directly comparable across epochs)
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
# (float targets, as required by BCELoss)
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
# NOTE(review): beta1 is defined earlier in the file -- presumably ~0.5 as in
# the DCGAN setup; confirm its value there.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Commented out IPython magic to ensure Python compatibility.
# Training Loop

# Lists to keep track of progress
img_list = []  # generated-image grids for visualization
G_losses = []  # generator loss per iteration
D_losses = []  # discriminator loss per iteration
iters = 0  # global iteration counter

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
Пример #25
0
def main():
    """Train and evaluate a GraphCNN whole-graph classifier.

    Performs 10 random train/val/test splits.  Each run trains with Adam +
    StepLR, early-stops when validation accuracy/loss has not improved for
    ``patience`` epochs, and records the test accuracy measured at the best
    validation point.  Finally prints mean and std over the 10 runs.
    """
    # Training settings
    # Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
    parser = argparse.ArgumentParser(
        description=
        'PyTorch graph convolutional neural net for whole-graph classification'
    )
    # help texts below corrected to match the actual defaults
    parser.add_argument('--dataset',
                        type=str,
                        default="PROTEINS",
                        help='name of dataset (default: PROTEINS)')
    parser.add_argument('--device',
                        type=int,
                        default=0,
                        help='which gpu to use if any (default: 0)')
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument(
        '--iters_per_epoch',
        type=int,
        default=50,
        help='number of iterations per each epoch (default: 50)')
    parser.add_argument('--epochs',
                        type=int,
                        default=10000,
                        help='number of epochs to train (default: 10000)')
    parser.add_argument('--lr',
                        type=float,
                        default=0.05,
                        help='learning rate (default: 0.05)')
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help='random seed for splitting the dataset into 10 (default: 42)')
    parser.add_argument(
        '--fold_idx',
        type=int,
        default=1,
        help='the index of fold in 10-fold validation. Should be less then 10.'
    )
    parser.add_argument(
        '--num_layers',
        type=int,
        default=2,
        help='number of layers INCLUDING the input one (default: 2)')
    parser.add_argument(
        '--num_mlp_layers',
        type=int,
        default=2,
        help=
        'number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.'
    )
    parser.add_argument('--hidden_dim',
                        type=int,
                        default=64,
                        help='number of hidden units (default: 64)')
    parser.add_argument('--rank',
                        type=int,
                        default=64,
                        help='rank hyper-parameter passed to the model (default: 64)')
    parser.add_argument('--dropout',
                        type=float,
                        default=0.5,
                        help='final layer dropout (default: 0.5)')
    parser.add_argument(
        '--graph_pooling_type',
        type=str,
        default="sum",
        choices=["sum", "average"],
        help='Pooling for over nodes in a graph: sum or average')
    parser.add_argument(
        '--neighbor_pooling_type',
        type=str,
        default="sum",
        choices=["sum", "average", "max"],
        help='Pooling for over neighboring nodes: sum, average or max')
    parser.add_argument(
        '--learn_eps',
        action="store_true",
        help=
        'Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.'
    )
    parser.add_argument(
        '--degree_as_tag',
        action="store_true",
        help=
        'let the input node features be the degree of nodes (heuristics for unlabeled graph)'
    )
    parser.add_argument('--filename', type=str, default="", help='output file')
    args = parser.parse_args()

    # set up seeds and gpu device
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    device = torch.device(
        "cuda:" +
        str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    # these datasets ship in a different format and use a dedicated loader
    if args.dataset in {'DD', 'FRANKENSTEIN', 'NCI1', 'NCI109'}:
        graphs, num_classes = load_torch_data(args.dataset)
    else:
        graphs, num_classes = load_data(args.dataset, args.degree_as_tag)

    patience = 200  # early-stopping patience (epochs without improvement)
    result = np.zeros(10)  # test accuracy of each of the 10 runs
    for idx in range(10):
        # random split; one experiment per split
        train_graphs, val_graphs, test_graphs = rand_train_test_graph(graphs)

        model = GraphCNN(args.num_layers, args.num_mlp_layers,
                         train_graphs[0].node_features.shape[1],
                         args.hidden_dim, args.rank, num_classes, args.dropout,
                         args.learn_eps, args.graph_pooling_type,
                         args.neighbor_pooling_type, device).to(device)

        optimizer = optim.Adam(model.parameters(), lr=args.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer,
                                              step_size=50,
                                              gamma=0.5)
        tlss_mn = np.inf  # best (lowest) validation loss so far
        tacc_mx = 0.0  # best (highest) validation accuracy so far
        curr_step = 0  # epochs since last improvement
        best_test = 0

        for epoch in range(1, args.epochs + 1):
            avg_loss = train(args, model, device, train_graphs, optimizer,
                             epoch)
            acc_train, acc_val, loss_val = validate(args, model, device,
                                                    train_graphs, val_graphs,
                                                    epoch)
            scheduler.step()

            if not args.filename == "":
                # NOTE: mode 'w' truncates each epoch, so the file only ever
                # holds the latest epoch's numbers (original behavior kept).
                with open(args.filename, 'w') as f:
                    # BUG FIX: previously wrote the undefined name
                    # `acc_test`, which raised NameError; log acc_val.
                    f.write("%f %f %f" % (avg_loss, acc_train, acc_val))
                    f.write("\n")
            print("")

            # keep the test score from an epoch that improves BOTH the
            # validation accuracy and the validation loss
            if acc_val > tacc_mx and loss_val < tlss_mn:
                best_test = test(args, model, device, test_graphs)
                tacc_mx = acc_val
                tlss_mn = loss_val
                curr_step = 0
            else:
                curr_step += 1
                if curr_step >= patience:
                    break

        print(best_test, args.lr, args.dropout)
        result[idx] = best_test
        del model, optimizer
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(args.dataset, args.rank)
    print(
        "learning rate %.4f, dropout %.4f, Test Result: %.4f, Test Std: %.4f" %
        (args.lr, args.dropout, np.mean(result), np.std(result)))
Пример #26
0
        generator = nn.DataParallel(generator, device_ids=range(num_GPUS))

# ---------
# fine tune
# ---------
# When fine-tuning, switch to the dedicated fine-tune learning rate and set
# the generator's flag for freezing its encoder BatchNorm layers.
if args.finetune != 0:
    print('Fine tune...')
    lr = args.lr_finetune
    generator.freeze_ec_bn = True  # NOTE(review): presumably consumed inside the generator -- confirm
else:
    lr = args.gen_lr

# ---------
# optimizer
# ---------
G_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(args.b1, args.b2))

writer = SummaryWriter(args.log_dir)  # TensorBoard logger

# --------
# Training
# --------
print('Start train...')
count = 0  # global step counter across epochs
for epoch in range(start_iter, args.train_epochs + 1):

    generator.train()
    for _, (input_images, ground_truths, masks) in enumerate(data_loader):

        count += 1
        start_time = time.time()
Пример #27
0
def models_train(X_train, y_train, X_valid, y_valid, X_test, y_test, configs, args):
    """Train a sequence-fusion regressor and report test-set metrics.

    Selects the architecture from ``args.model`` (mfn / tman1 / marn /
    tman2), trains it with L1 loss, keeps the checkpoint with the best
    validation loss, then reloads that checkpoint and prints MAE, Pearson
    correlation, confusion matrix, binary accuracy and F1 on the test split
    (labels binarized at 3.5).

    Inputs arrive batch-major and are swapped to time-major
    (time, batch, ...) before being fed to the models.
    """
    set_random_seed()

    # shuffle the training set once up front
    perm = np.random.permutation(X_train.shape[0])
    X_train = X_train[perm]
    y_train = y_train[perm]

    # models expect time-major input
    X_train = X_train.swapaxes(0, 1)
    X_valid = X_valid.swapaxes(0, 1)
    X_test = X_test.swapaxes(0, 1)

    if args.model == "mfn":
        model = MFN(*configs)
    elif args.model == "tman1":
        model = TMAN1(*configs)
    elif args.model == "marn":
        model = MARN(*configs)
    else:
        model = TMAN2(*configs)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    criterion = nn.L1Loss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = criterion.to(device)

    def train(model, batchsize, X_train, y_train, optimizer, criterion):
        """One epoch of mini-batch training; returns the mean batch loss."""
        epoch_loss = 0
        model.train()
        total_n = X_train.shape[1]
        # BUG FIX: was ceil(total_n // batchsize) -- the floor-division ran
        # first, making ceil a no-op and silently dropping the final
        # partial batch.
        num_batches = ceil(total_n / batchsize)
        for batch in range(num_batches):
            start = batch * batchsize
            end = (batch + 1) * batchsize
            optimizer.zero_grad()
            # .to(device) instead of .cuda() so CPU-only machines work,
            # matching the `device` selection above
            batch_X = torch.Tensor(X_train[:, start:end]).to(device)
            batch_y = torch.Tensor(y_train[start:end]).squeeze(-1).to(device)
            predictions = model.forward(batch_X).squeeze(1)
            loss = criterion(predictions, batch_y)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        return epoch_loss / num_batches

    def validate(model, X_valid, y_valid, criterion):
        """Full-batch validation loss (no gradients)."""
        model.eval()
        with torch.no_grad():
            batch_X = torch.Tensor(X_valid).to(device)
            batch_y = torch.Tensor(y_valid).squeeze(-1).to(device)
            predictions = model.forward(batch_X).squeeze(1)
            epoch_loss = criterion(predictions, batch_y).item()
        return epoch_loss

    def predict(model, X_test):
        """Full-batch inference; returns predictions as a numpy array."""
        model.eval()
        with torch.no_grad():
            batch_X = torch.Tensor(X_test).to(device)
            predictions = model.forward(batch_X).squeeze(1)
            predictions = predictions.cpu().data.numpy()
        return predictions

    best_valid_loss = float('inf')
    for epoch in trange(args.epoch):
        train_loss = train(model, args.bs, X_train, y_train, optimizer, criterion)
        valid_loss = validate(model, X_valid, y_valid, criterion)
        if valid_loss <= best_valid_loss:
            best_valid_loss = valid_loss
            print(str(epoch)+'th epoch', ', train loss is', train_loss, ', valid loss is', valid_loss)
            print('Saving model...')
            torch.save(model, 'saved_models/' + args.model + '_mmmo.pt')
        else:
            print(str(epoch) + 'th epoch', ', train loss is', train_loss, ', valid loss is', valid_loss)

    # reload the best checkpoint; map_location keeps CPU-only runs working
    # even when the checkpoint was saved from a GPU
    model = torch.load('saved_models/' + args.model + '_mmmo.pt',
                       map_location=device)

    predictions = predict(model, X_test)
    y_test = np.squeeze(y_test)
    mae = np.mean(np.absolute(predictions - y_test))
    print("MAE:", mae)
    corr = np.corrcoef(predictions, y_test)[0][1]
    print("corr:", corr)
    true_label = (y_test > 3.5)
    predicted_label = (predictions > 3.5)
    print("Confusion Matrix:")
    print(confusion_matrix(true_label, predicted_label))
    bi_acc = accuracy_score(true_label, predicted_label)
    print("Accuracy:", bi_acc)
    f1 = round(f1_score(true_label, predicted_label, average='binary'), 5)
    print("F1 score:", f1)
Пример #28
0
    def forward(self, epoch, images, captions, lengths, save_images):
        """One training step of the coupled image/text VAE-GAN.

        Dispatches the batch either to generator training (``train_G``) or
        discriminator training (``train_D``) based on the iteration count,
        periodically saving model checkpoints and sample reconstructions.

        NOTE(review): in this revision only the text branch is trained
        (train_txt=True, train_img=False); the epoch-based alternation is
        left commented out below.
        """
        self.iteration += 1

        # if epoch >  self.args.mixed_training_epoch:
        #     train_img = True
        #     train_txt = True
        # elif not epoch % 2:
        #     train_img = True
        #     train_txt = False
        # else:
        #     train_img = False
        #     train_txt = True

        train_txt = True
        train_img = False
        # periodic checkpointing (skipped while models are still loading)
        if not self.iteration % self.args.model_save_interval and not self.keep_loading:
            self.save_models(self.iteration)

        # Save samples
        # Run the coupled VAE on this batch purely to produce sample
        # reconstructions for inspection; the variational variant also
        # returns the latent means/log-variances.
        if save_images:
            if self.use_variational:
                img2img_out, txt2img_out, img2txt_out, txt2txt_out, img_z, txt_z, img_mu, img_logv, txt_mu, txt_logv = \
                                 self.networks["coupled_vae"](images, captions , lengths)
            else:
                img2img_out, txt2img_out, img2txt_out, txt2txt_out, img_z, txt_z= \
                                 self.networks["coupled_vae"](images, captions , lengths)

            self.save_samples(images[0], img2img_out[0], txt2img_out[0], captions[0], txt2txt_out[0], img2txt_out[0])

        # during VAE pretraining only the generator/VAE side is updated
        if self.iteration < self.args.vae_pretrain_iterations:
            self.train_G(epoch,images,captions,lengths,
                         pretrain= True, train_img=train_img, train_txt=train_txt)
            return


        # At the exact iteration pretraining ends, optionally freeze the
        # encoders and rebuild the Adam optimizer over trainable params.
        # NOTE(review): assigning `requires_grad` on an nn.Module only sets a
        # plain attribute; it does not freeze that module's parameters
        # (that would need requires_grad_() on the parameters) -- confirm
        # whether valid_params() compensates for this.
        if self.freeze_encoders and self.iteration == self.args.vae_pretrain_iterations:
            self.networks["coupled_vae"].encoder_rnn.requires_grad = False
            self.networks["coupled_vae"].encoder_cnn.requires_grad = False
            if self.use_variational:
                self.networks["coupled_vae"].hidden2mean_txt.requires_grad = False
                self.networks["coupled_vae"].hidden2mean_img.requires_grad = False
                self.networks["coupled_vae"].hidden2logv_txt.requires_grad = False
                self.networks["coupled_vae"].hidden2logv_img.requires_grad = False

            self.optimizers["coupled_vae"] = optim.Adam(valid_params(self.networks["coupled_vae"].parameters()),\
                                                    lr=self.args.learning_rate, betas=(0.5, 0.999)) #, weight_decay=0.00001)


        # self.epoch_max_len.fill_(5+int(epoch/2))
        # lengths = torch.min(self.epoch_max_len, lengths)

        # if self.iteration < self.args.vae_pretrain_iterations + self.args.disc_pretrain_iterations:
        #     cycle = 101
        # else:
        #     cycle = self.args.gan_cycle

        cycle = self.args.gan_cycle

        # cycle = 60

        # Every `cycle`-th iteration trains the generator with discriminator
        # gradients disabled; all other iterations train the discriminators.
        # train_gen = self.iteration > 500 and self.iteration % 6
        if not self.iteration % cycle:
        # if not self.iteration % cycle or max(self.losses["Ls_G_img"].val,self.losses["Ls_G_txt"].val) > 0.75:
        # if self.iteration % cycle < 50:
             #
            for p in self.networks["img_discriminator"].parameters():  # reset requires_grad
                p.requires_grad = False  # they are set to False below in netG update
            for p in self.networks["txt_discriminator"].parameters():  # reset requires_grad
                p.requires_grad = False  # they are set to False below in netG update

            self.train_G(epoch, images, captions, lengths, train_img = train_img,
                         train_txt= train_txt)
            # self.optimizers["coupled_vae"].step()
            #
            for p in self.networks["img_discriminator"].parameters():  # reset requires_grad
                p.requires_grad = True  # they are set to False below in netG update
            for p in self.networks["txt_discriminator"].parameters():  # reset requires_grad
                p.requires_grad = True  # they are set to False below in netG update
        else:

            self.train_D(epoch, images, captions, lengths, train_img, train_txt)
Пример #29
0
    def decode(self, z):
        """Map a latent code *z* back to a Bernoulli parameter vector in (0, 1)."""
        hidden = self.fc3(z)
        activated = F.relu(hidden)
        logits = self.fc4(activated)
        return torch.sigmoid(logits)

    def forward(self, x):
        """Run a full VAE pass: encode, sample a latent, decode.

        Returns ``(reconstruction, mu, logvar)`` for the loss computation.
        """
        flat = x.view(-1, 784)
        mu, logvar = self.encode(flat)
        latent = self.reparameterize(mu, logvar)
        reconstruction = self.decode(latent)
        return reconstruction, mu, logvar


# Build the VAE and move it onto the device selected earlier in the file.
print('Building model...')
model = VAE().to(device)

print('Initializing optimizer...')  # BUG FIX: message typo ("Intializing")
optimizer = optim.Adam(model.parameters(), lr=1e-3)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar, beta=None):
    """Return the beta-VAE loss: reconstruction BCE plus weighted KL term.

    Args:
        recon_x: reconstructed batch with values in (0, 1).
        x: original input batch; viewed as ``(-1, 784)`` (flattened MNIST).
        mu: latent mean from the encoder.
        logvar: latent log-variance from the encoder.
        beta: weight on the KL term.  Defaults to the module-level
            ``args.beta`` for backward compatibility; passing it explicitly
            makes the function usable without the global.
    """
    if beta is None:
        beta = args.beta
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction="sum")

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + beta * KLD

# Example #30
# (score: 0)
def create_trainer(
    config_dict: Dict[str, Any],
    output: Path,
) -> Trainer:
    """Build a fully-wired Trainer for the predictor described by ``config_dict``.

    Side effects: creates ``output`` (with parents) and writes ``config.yaml``
    and ``struct.txt`` into it; opens a TensorBoard ``SummaryWriter`` on it.

    NOTE(review): torch.device("cuda") is hard-coded below — confirm CPU-only
    hosts are not a supported target.
    """
    # config
    config = Config.from_dict(config_dict)
    config.add_git_info()

    # Persist the effective config next to the artifacts for reproducibility.
    output.mkdir(exist_ok=True, parents=True)
    with (output / "config.yaml").open(mode="w") as f:
        yaml.safe_dump(config.to_dict(), f)

    # model
    predictor = create_predictor(config.network)
    model = Model(model_config=config.model, predictor=predictor)
    if config.train.weight_initializer is not None:
        init_weights(model, name=config.train.weight_initializer)

    device = torch.device("cuda")
    model.to(device)

    # dataset: all three iterators share batch size and worker settings.
    _create_iterator = partial(
        create_iterator,
        batch_size=config.train.batch_size,
        num_processes=config.train.num_processes,
        use_multithread=config.train.use_multithread,
    )

    datasets = create_dataset(config.dataset)
    train_iter = _create_iterator(datasets["train"], for_train=True)
    test_iter = _create_iterator(datasets["test"], for_train=False)
    # Generation-eval reuses the test split with eval-specific settings.
    eval_iter = _create_iterator(datasets["test"], for_train=False, for_eval=True)

    # Escalate iterator timeout warnings into hard errors so stalls surface.
    warnings.simplefilter("error", MultiprocessIterator.TimeoutWarning)

    # optimizer: "name" selects the class; the remaining keys are its kwargs.
    cp: Dict[str, Any] = copy(config.train.optimizer)
    n = cp.pop("name").lower()

    optimizer: Optimizer
    if n == "adam":
        optimizer = optim.Adam(model.parameters(), **cp)
    elif n == "sgd":
        optimizer = optim.SGD(model.parameters(), **cp)
    else:
        raise ValueError(n)

    # updater
    updater = StandardUpdater(
        iterator=train_iter,
        optimizer=optimizer,
        model=model,
        device=device,
    )

    # trainer: all triggers are expressed in iterations, not epochs.
    trigger_log = (config.train.log_iteration, "iteration")
    trigger_eval = (config.train.snapshot_iteration, "iteration")
    trigger_stop = (
        (config.train.stop_iteration, "iteration")
        if config.train.stop_iteration is not None
        else None
    )

    trainer = Trainer(updater, stop_trigger=trigger_stop, out=output)
    writer = SummaryWriter(Path(output))

    # Plain loss evaluation on the test split, every log interval.
    ext = extensions.Evaluator(test_iter, model, device=device)
    trainer.extend(ext, name="test", trigger=trigger_log)

    # Generation-based evaluation (runs the full Generator), every eval interval.
    generator = Generator(
        config=config,
        predictor=predictor,
        use_gpu=True,
    )
    generate_evaluator = GenerateEvaluator(
        generator=generator,
    )
    ext = extensions.Evaluator(eval_iter, generate_evaluator, device=device)
    trainer.extend(ext, name="eval", trigger=trigger_eval)

    # Retain roughly one predictor snapshot per ten eval points, at least five.
    saving_model_num = 0
    if config.train.stop_iteration is not None:
        saving_model_num = int(
            config.train.stop_iteration / config.train.snapshot_iteration / 10
        )
    saving_model_num = max(saving_model_num, 5)

    # Snapshot the predictor, gated by LowValueTrigger on "eval/main/f0_diff".
    ext = extensions.snapshot_object(
        predictor,
        filename="predictor_{.updater.iteration}.pth",
        n_retains=saving_model_num,
    )
    trainer.extend(
        ext,
        trigger=LowValueTrigger("eval/main/f0_diff", trigger=trigger_eval),
    )

    # Reporting: abort on NaN/Inf, log LR and losses to console and TensorBoard.
    trainer.extend(extensions.FailOnNonNumber(), trigger=trigger_log)
    trainer.extend(extensions.observe_lr(), trigger=trigger_log)
    trainer.extend(extensions.LogReport(trigger=trigger_log))
    trainer.extend(
        extensions.PrintReport(["iteration", "main/loss", "test/main/loss"]),
        trigger=trigger_log,
    )

    ext = TensorboardReport(writer=writer)
    trainer.extend(ext, trigger=trigger_log)

    # Optional Weights & Biases reporting, enabled when a project category is set.
    if config.project.category is not None:
        ext = WandbReport(
            config_dict=config.to_dict(),
            project_category=config.project.category,
            project_name=config.project.name,
            output_dir=output.joinpath("wandb"),
        )
        trainer.extend(ext, trigger=trigger_log)

    # Dump the model structure for quick offline inspection.
    (output / "struct.txt").write_text(repr(model))

    if trigger_stop is not None:
        trainer.extend(extensions.ProgressBar(trigger_stop))

    # Single rolling snapshot of the whole trainer; autoload=True presumably
    # resumes from the latest snapshot on restart — TODO confirm against the
    # snapshot_object documentation.
    ext = extensions.snapshot_object(
        trainer,
        filename="trainer_{.updater.iteration}.pth",
        n_retains=1,
        autoload=True,
    )
    trainer.extend(ext, trigger=trigger_eval)

    return trainer