Ejemplo n.º 1
0
    def train(self):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.
        """
        # Optional warm-up of the shared weights before the main loop
        # (shared_initial_step defaults to 0 in comparable variants).
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            self.train_shared()

            # 2. Training the controller parameters theta
            # NOTE(review): controller training is disabled here — confirm
            # whether that is intentional (e.g. a fixed-architecture run).
            #self.train_controller()

            # Periodically derive the best architecture, evaluate it, and
            # checkpoint.  The previous version had a separate, identical
            # `if self.epoch == 0` branch; since 0 % save_epoch == 0 already
            # holds, it caused a duplicate evaluation and save at epoch 0.
            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(iter(self.test_data),
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            # Decay the shared-weights learning rate past the cutoff epoch.
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)
Ejemplo n.º 2
0
    def train(self):
        """Alternate between training the shared child-model weights and the
        controller, following Section 2.2 ("Training ENAS and Deriving
        Architectures") of the paper.

        From the paper (for Penn Treebank):

        - Phase one trains the shared parameters omega for 400 steps, each
          on a minibatch of 64 examples.

        - Phase two trains the controller's parameters for 2000 steps.
        """
        # Optional warm-up of the shared weights before the alternating loop.
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # Phase 1: shared parameters omega of the child models.
            self.train_shared()

            # Phase 2: controller parameters theta.
            self.train_controller()

            # On save epochs: derive the current best architecture,
            # evaluate it on held-out data, and write a checkpoint.
            is_save_epoch = self.epoch % self.args.save_epoch == 0
            if is_save_epoch:
                with _get_no_grad_ctx_mgr():
                    best_dag = self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()

            # Decay the shared learning rate once past the cutoff epoch.
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)
Ejemplo n.º 3
0
    def train(self, epoch, train_dataloader):
        """Run one training epoch over *train_dataloader*.

        Args:
            epoch (int): Current epoch index, used for LR scheduling and
                logging.
            train_dataloader: Iterable yielding ``(ang, pos, ori)`` tensor
                batches.
        """
        self.model.train()
        lr = utils.update_lr(epoch, self.cfg_stg, self.optimizer)
        total_loss1 = 0  # accumulated position loss
        total_loss2 = 0  # accumulated orientation loss
        batch_num = 0
        for ang, pos, ori in train_dataloader:
            ang, pos, ori = ang.to(self.device), pos.to(self.device), ori.to(self.device)

            pred_pos, pred_ori_pow2 = self.model(ang)

            loss1, loss2 = self.criterion(pos, ori, pred_pos, pred_ori_pow2)
            # Combined objective: position loss plus weighted orientation loss.
            loss = loss1 + self.cfg_stg['loss_weight_ori'] * loss2

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss1 += loss1.item()
            total_loss2 += loss2.item()
            batch_num += 1

        if batch_num == 0:
            # Empty dataloader: nothing to average or log (previously this
            # raised ZeroDivisionError below).
            return

        avg_loss_pos = total_loss1 / batch_num
        avg_loss_ori = total_loss2 / batch_num
        self.tb_logger.add_scalar('train_loss_pos', avg_loss_pos, epoch)
        self.tb_logger.add_scalar('train_loss_ori', avg_loss_ori, epoch)
        # The original guard `if epoch % 1 == 0` was always true, so the
        # summary is simply logged every epoch.
        self.logger.info('Train: epoch {}, lr {:.6f}, loss_pos {:.6f}, loss_ori {:.6f}'.format(
            epoch, lr, avg_loss_pos, avg_loss_ori))
Ejemplo n.º 4
0
def train():
    """Train the autoencoder over the configured epoch schedule.

    Relies on module-level objects: ``autoencoder``, ``trainloader``,
    ``optimizer``, ``criterion``, ``progress_bar``, ``chrono``,
    ``get_torch_vars``, ``update_lr`` and ``Utils``.
    """
    autoencoder.train()
    # Milestone schedule consumed by update_lr; the last entry is the total
    # number of epochs to run.
    epochs = [1, 5, 10]
    for epoch in range(epochs[-1]):
        running_loss = 0.0
        progress_bar.newbar(len(trainloader))
        for batch_idx, (inputs, _) in enumerate(trainloader):
            with chrono.measure("step_time"):
                inputs = get_torch_vars(inputs)

                # Per-step LR schedule; None signals the schedule is
                # exhausted, so stop this epoch early.
                lr = update_lr(optimizer, epoch, epochs, 0.003, batch_idx,
                               len(trainloader))
                if lr is None:
                    break

                _, decoded = autoencoder(inputs)
                loss = criterion(decoded, inputs)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # .item() extracts a detached Python float; the previous
                # `loss.data` accumulated a tensor on every step.
                running_loss += loss.item()

            msg = 'Step: %s | Tot: %s | LR: %.10f | Loss: %.3f' % \
                  (Utils.format_time(chrono.last('step_time')),
                   Utils.format_time(chrono.total('step_time')),
                   lr,
                   running_loss / (batch_idx + 1))
            progress_bar.update(batch_idx, msg)

        chrono.remove("step_time")
Ejemplo n.º 5
0
    def train(self, single=False):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.

        Args:
            single (bool): If True it won't train the controller and use the
                           same dag instead of derive().
        """
        dag = utils.load_dag(self.args) if single else None  # search mode: dag starts as None

        if self.args.shared_initial_step > 0:  # self.args.shared_initial_step default=0
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(
                self.start_epoch,
                self.args.max_epoch):  # start_epoch=0, max_epoch=150
            # 1. Training the shared parameters omega of the child models:
            #    the controller samples a dag, an RNN cell is built from that
            #    dag, and the cell is trained on next-word prediction loss.
            self.train_shared(dag=dag)

            # 2. Training the controller parameters theta (skipped in
            #    single mode, where the architecture is fixed).
            if not single:
                self.train_controller()

            # On save epochs, derive (or reuse) the best dag, evaluate it,
            # and checkpoint.
            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = dag if dag else self.derive()
                    self.evaluate(self.eval_data,
                                  best_dag,
                                  'val_best',
                                  max_num=self.args.batch_size * 100)
                self.save_model()
            # Gradually decay the shared learning rate past the cutoff epoch.
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)
Ejemplo n.º 6
0
    def train(self, single=False):
        """Alternately train the shared child-model weights and the
        controller (Section 2.2, "Training ENAS and Deriving Architectures"
        of the paper).

        From the paper (for Penn Treebank): phase one trains the shared
        parameters omega for 400 steps on minibatches of 64 examples; phase
        two trains the controller's parameters for 2000 steps.

        Args:
            single (bool): When True, skip controller training and reuse a
                single loaded dag instead of calling derive().
        """
        # Reset the reward baseline for this training run.
        self.baseline = None

        fixed_dag = utils.load_dag(self.args, self.logger) if single else None

        # Optional warm-up of the shared weights before the main loop.
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Shared parameters omega of the child models.
            self.train_shared(dag=fixed_dag)

            # 2. Controller parameters theta (search mode only).
            if not single:
                self.train_controller()

            # Evaluate and checkpoint on save epochs, but never on epoch 0.
            if self.epoch != 0 and self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = fixed_dag or self.derive()
                    self.evaluate(best_dag, batch_size=self.args.batch_size)
                self.save_model()

            # Decay the shared learning rate past the cutoff epoch.
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)

        # Final checkpoint and cleanup of the dag log file.
        self.save_model()
        self.dag_file.close()
Ejemplo n.º 7
0
    def train(self):
        """Skeleton of the alternating ENAS search loop: the actual training
        calls are commented out and only progress messages are printed."""
        for self.epoch in range(self.args.max_epoch):
            # Phase 1 placeholder: shared weights
            # (400 steps, each on a minibatch of 64 examples).
            print(f'train_shared, cur_shared_step: {self.shared_step}')
            # self.train_shared()

            # Phase 2 placeholder: controller
            # (2000 steps, each on a minibatch of 1 examples).
            print(
                f'train_controller, cur_controller_step: {self.controller_step}'
            )
            # self.train_controller()

            # Derive and evaluate on the last epoch of every save window.
            if self.epoch % self.args.save_epoch == self.args.save_epoch - 1:
                with torch.no_grad():
                    best_dag = self.derive()
                    ppl = self.evaluate(self.eval_data, best_dag)
                    print(f'ppl: {ppl}')

            # Decay the shared learning rate past the cutoff epoch.
            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)
Ejemplo n.º 8
0
    def train(self):
        """Run one training epoch over self.trainloader, updating the model
        and recording loss/accuracy statistics on self for later reporting."""
        self.model.train()
        self.train_loss = 0  # running sum of per-batch losses
        correct = 0
        total = 0
        self.pred = []  # raw model outputs per batch, kept for later analysis
        self.progress_bar.newbar(len(self.trainloader))
        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
            with self.chrono.measure("step_time"):
                inputs = get_torch_vars(inputs)
                targets = get_torch_vars(targets)

                # Per-step LR schedule; None signals the schedule is
                # exhausted and this epoch should stop early.
                self.lr = update_lr(self.optimizer,
                                    self.epoch, self.epochs,
                                    self.initial_lr,
                                    batch_idx, len(self.trainloader))
                if self.lr is None:
                    break

                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                # Loss computed in double precision.
                loss = self.criterion(outputs.double(), targets.double())
                loss.backward()
                self.optimizer.step()

                self.train_loss += loss.item()
                # Round-to-nearest thresholding of the outputs; assumes
                # outputs and targets share an integer label scale — TODO
                # confirm against the model head and dataset.
                predicted = (outputs + 0.5).int()
                total += targets.size(0)
                correct += (predicted == targets).sum().item()
                self.pred.append(outputs.cpu().data.numpy())

            # Progress line: step time, total time, LR, mean loss, accuracy.
            msg = self.step_msg % (Utils.format_time(self.chrono.last('step_time')),
                                   Utils.format_time(self.chrono.total('step_time')),
                                   self.lr,
                                   self.train_loss / (batch_idx + 1),
                                   100. * correct / total,
                                   correct,
                                   total)
            self.progress_bar.update(batch_idx, msg)

        self.chrono.remove("step_time")
        self.train_acc = 100. * correct / total
Ejemplo n.º 9
0
    def train(self):
        """Run one training epoch: encode inputs with the frozen autoencoder,
        train the classifier on the encoded features, and record
        loss/accuracy statistics on self."""
        self.ae.eval()  # autoencoder is used for inference only here
        self.model.train()
        self.train_loss = 0  # running sum of per-batch losses
        correct = 0
        total = 0
        self.progress_bar.newbar(len(self.trainloader))
        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
            with self.chrono.measure("step_time"):
                inputs = get_torch_vars(inputs)
                targets = get_torch_vars(targets)

                # Per-step LR schedule; None signals the schedule is
                # exhausted and this epoch should stop early.
                self.lr = update_lr(self.optimizer, self.epoch, self.epochs,
                                    self.initial_lr, batch_idx,
                                    len(self.trainloader))
                if self.lr is None:
                    break

                self.optimizer.zero_grad()
                # Classify from the autoencoder's latent representation.
                encoded, _ = self.ae(inputs)
                outputs = self.model(encoded)
                loss = self.criterion(outputs, targets)
                loss.backward()
                self.optimizer.step()

                self.train_loss += loss.item()
                # Predicted class = argmax over the class dimension.
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

            # Progress line: step time, total time, LR, mean loss, accuracy.
            msg = self.step_msg % (
                Utils.format_time(self.chrono.last('step_time')),
                Utils.format_time(
                    self.chrono.total('step_time')), self.lr, self.train_loss /
                (batch_idx + 1), 100. * correct / total, correct, total)
            self.progress_bar.update(batch_idx, msg)

        self.chrono.remove("step_time")
        self.train_acc = 100. * correct / total
Ejemplo n.º 10
0
            if opt.debug:
                break

        if epoch % opt.epoch_save == 0:
            torch.save(
                G.state_dict(),
                os.path.join(test_opt.model_dir, "{:d}_G.pt".format(epoch)))
            torch.save(
                D.state_dict(),
                os.path.join(test_opt.model_dir, "{:d}_D.pt".format(epoch)))

            image_dir = os.path.join(test_opt.image_dir, "{:d}".format(epoch))
            os.makedirs(image_dir, exist_ok=True)
            with torch.no_grad():
                for i, (input, target) in enumerate(test_data_loader):
                    input = input.to(DEVICE)
                    fake = G(input)
                    manager.save_image(
                        fake, os.path.join(image_dir,
                                           "{:d}_fake.png".format(i)))
                    manager.save_image(
                        target,
                        os.path.join(image_dir, "{:d}_real.png".format(i)))

        if epoch > opt.epoch_decay:
            lr = update_lr(opt.lr, lr, opt.n_epochs - opt.epoch_decay, D_optim,
                           G_optim)

    print("Total time taken: ", datetime.datetime.now() - start_time)
Ejemplo n.º 11
0
            model = Stacked_mLSTM(mLSTM, layers, embed_size, rnn_size,
                                  data_size, dropout)

        loss_fn = nn.CrossEntropyLoss()
        embed_optimizer = optim.Adam(embedding.parameters(), lr=lr)
        model_optimizer = optim.Adam(model.parameters(), lr=lr)

        n_params = sum([p.nelement() for p in model.parameters()])
        print('Total number of parameters:', n_params)
        print('Total number of batches:', num_batches)
        print()
        print('Embedding Summary:')
        print(embedding)
        print()
        print('RNN Summary:')
        print(model)

        for e in range(int(options.epochs)):
            try:
                train_model(e)
                lr *= 0.7
                utils.update_lr(model_optimizer, lr)
                utils.update_lr(embed_optimizer, lr)
            except KeyboardInterrupt:
                print('KeyboardInterrupt occured, saving the model')
                utils.save_model(options, model, embedding, e)
                writer.export_scalars_to_json("./all_scalars.json")
                writer.close()
                break
        utils.save_model(options, model, embedding, e)
def main():
    """Build and train a semi-supervised adversarial segmentation model
    (TensorFlow 1.x graph style): a segmentation generator G versus a
    discriminator D, adding an unlabelled dataset's semi-supervised loss
    after `start_semi` iterations."""
    dataset_path=r'/home/*/f3.hdf5'  # labelled dataset (full or half set)
    dataset_nl_path=r'/home/*/hdf5/f2.hdf5'  # unlabelled dataset
    model_save_path=r'/home/*/model/model.ckpt'
    
    learning_rate_G=2.5e-4
    learning_rate_D=1e-4
    class_number=5
    
    batch_size=8
    max_iters=20000
    start_semi=5000  # iteration at which semi-supervised training starts
    mask_T=0.2  # confidence-mask threshold used by the semi loss
    lambda_abv=0.01  # weight of the adversarial term in G's loss
    lambde_semi=0.1  # weight of the semi-supervised loss
    
    dataset_param=batch_data.Data(dataset_path)
    iter_num=dataset_param.img_num//batch_size
    print(iter_num)
    # NOTE(review): lowercase `batch_data.data` here vs `batch_data.Data`
    # above — confirm both names exist in batch_data.
    dataset_nl_param=batch_data.data(dataset_nl_path)
    
    # G placeholders; `image` carries both labelled and unlabelled batches.
    image = tf.placeholder(tf.float32, shape = [None,None,None,3],name='image')
    label = tf.placeholder(tf.float32, shape = [None,None,None,class_number],name='label')
    is_training = tf.placeholder(tf.bool,name='is_training')
    
    G_score, G_softmax, end_points = build_G(image,class_number,is_training,False,True)
    #print(end_points)
    L2_loss=tf.losses.get_regularization_loss()
#    init_fn=slim.assign_from_checkpoint_fn(r'/home/*/model.ckpt',slim.get_model_variables('u_net'))
    
    # Build D twice: fake branch here, real branch below with reuse=True.
    D_score_fake, D_sigmoid_fake = build_D(G_softmax)
    ''' Core loss: Loss_Seg_Adv,  Loss_Semi, Loss_D '''
    
    with tf.name_scope('loss_g'):
        # Segmentation loss + weighted adversarial term + L2 regularization.
        Loss_Seg_Adv=soft_loss(logits=G_score,labels=label)+lambda_abv*sig_loss(D_score_fake,True)+L2_loss
        tf.summary.scalar('loss_g',Loss_Seg_Adv)
    
    # Semi-supervised loss on unlabelled data, masked by D's confidence.
    Loss_Semi=lambde_semi*nl_soft_loss(G_score,D_sigmoid_fake,class_number,mask_T)
        
    Loss_D_fake=sig_loss(D_score_fake,False)
    
    D_score_real, D_sigmoid_real = build_D(label,reuse=True)
    Loss_D_real=sig_loss(D_score_real,True)
    
    with tf.name_scope('loss_d'):
        Loss_D=Loss_D_fake+Loss_D_real
        tf.summary.scalar('loss_d',Loss_D)
    # Collect the trainable variables of G and D separately.
    all_vars=tf.trainable_variables()
    g_vars=[var for var in all_vars if 'u_net' in var.name]
    d_vars=[var for var in all_vars if 'FCDiscriminator' in var.name]
    
    ''' adjust lr ; loss_semi with loss_seg_adv same lr 
        because it's a graph , global_step is a Variable ,update same time.
    '''
    global_step_G=tf.Variable(tf.constant(0))
    lr_g=update_lr(learning_rate_G,max_iters//4,0.1,global_step_G)
    # Only the supplied g_vars are updated by this op; d_vars stay fixed,
    # so computing the semi loss does not update D either.
    train_op_G=update_optim(Loss_Seg_Adv,lr_g,g_vars,global_step_G)
    
    #lr_g=update_lr(learning_rate_G,max_iters//4,0.1,global_step_G)
    train_op_Semi=update_optim(Loss_Semi,lr_g,g_vars)  # semi loss shares G's learning rate
    
    global_step_D=tf.Variable(tf.constant(0))
    lr_d=update_lr(learning_rate_D,max_iters//4,0.1,global_step_D)
    train_op_D=update_optim(Loss_D,lr_d,d_vars,global_step_D)
    
    sess=tf.Session()
    merged=tf.summary.merge_all()
    writer=tf.summary.FileWriter('./model/',sess.graph)
    
    init_op=tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
    sess.run(init_op)
    saver=tf.train.Saver(max_to_keep=1)
    
#    init_fn(sess) # fine-tune
    continue_learning=False
    if continue_learning:
        # Resume from the latest checkpoint in the model directory.
        ckpt=tf.train.get_checkpoint_state(r'/home/*/model')
        saver.restore(sess,ckpt.model_checkpoint_path)
        
    for iters in range(max_iters):
        loss_semi=0
        if iters >= start_semi:
            # Unlabelled batch: targets are generated by the model itself.
            imgs_nl=dataset_nl_param.next_batch(batch_size)
            sess.run(train_op_Semi,feed_dict={image:imgs_nl,is_training:True})
            loss_semi += sess.run(Loss_Semi,feed_dict={image:imgs_nl,is_training:True})
        imgs,labs=dataset_param.next_batch(batch_size)
        sess.run(train_op_G,feed_dict={image:imgs,label:labs,is_training:True})
        sess.run(train_op_D,feed_dict={image:imgs,label:labs,is_training:True})
        # Write the merged loss summaries to TensorBoard.
        result=sess.run(merged,feed_dict={image:imgs,label:labs,is_training:True})
        writer.add_summary(result,iters)
        
        loss_seg_adv, loss_d = sess.run([Loss_Seg_Adv,Loss_D],feed_dict={image:imgs,label:labs,is_training:True})
        
        #print('loss_seg_adv: %.2f ,loss_d: %.2f ,loss_semi: %.2f '%(loss_seg_adv,loss_d,loss_semi))
        print('\rloss_seg_adv: %.2f ,loss_d: %.2f ,loss_semi: %.2f '%(loss_seg_adv,loss_d,loss_semi),end='',flush=True)
        
        if iters%1000==0 and iters!=0:
            saver.save(sess,save_path=model_save_path,global_step=iters)
Ejemplo n.º 13
0
                cc_bin_2x2 = np.mean(list_cc_bin_2x2)
                cc_bin_4x4 = np.mean(list_cc_bin_4x4)
                cc_bin_8x8 = np.mean(list_cc_bin_8x8)

                R1_mean = np.mean(list_R1)
                R1_std = np.std(list_R1)
                R2_mean = np.mean(list_R2)
                R2_std = np.std(list_R2)

                with open(os.path.join(test_model_dir, 'Analysis.txt'),
                          'a') as analysis:
                    analysis.write(
                        str(current_step) + ', ' + str(cc_TUMF[0][1]) + ', ' +
                        str(cc_1x1) + ', ' + str(cc_bin_2x2) + ', ' +
                        str(cc_bin_4x4) + ', ' + str(cc_bin_8x8) + ', ' +
                        str(R1_mean) + ', ' + str(R1_std) + ', ' +
                        str(R2_mean) + ', ' + str(R2_std) + '\n')
                    analysis.close()

                for p in G.parameters():
                    p.requires_grad_(True)

            if opt.debug:
                break

        if epoch > opt.epoch_decay and opt.HD:
            lr = update_lr(lr, init_lr, opt.n_epochs - opt.epoch_decay,
                           D_optim, G_optim)

    print("Total time taken: ", datetime.datetime.now() - start_time)
Ejemplo n.º 14
0
def train(opt, netG, netD, optim_G, optim_D):
    """Adversarial training loop for the blur-to-sharp GAN.

    Args:
        opt: Parsed options (batchSize, epoch, niter, niter_decay, gpu, lr,
            checkpoints_dir, save_epoch_freq, ...).  ``opt.lr`` is updated
            in place during LR decay.
        netG: Generator network mapping blurred images to sharp ones.
        netD: Discriminator network.
        optim_G: Optimizer for the generator.
        optim_D: Optimizer for the discriminator.
    """
    tensor = torch.cuda.FloatTensor

    train = ReadConcat(opt)
    trainset = DataLoader(train, batch_size=opt.batchSize, shuffle=True)
    save_img_path = os.path.join('./result', 'train')
    check_folder(save_img_path)

    for e in range(opt.epoch, opt.niter + opt.niter_decay + 1):
        for i, data in enumerate(trainset):
            # Set input: A is the blurred image, B the sharp target.
            data_A = data['A']
            data_B = data['B']

            if torch.cuda.is_available():
                data_A = data_A.cuda(opt.gpu)
                data_B = data_B.cuda(opt.gpu)
            # Forward pass through the generator.
            realA = Variable(data_A)
            fakeB = netG(realA)
            realB = Variable(data_B)

            # Optimize netD first (its gradients are enabled only here).
            set_requires_grad([netD], True)
            for iter_d in range(1):
                optim_D.zero_grad()
                loss_D, _ = get_loss(tensor, netD, realA, fakeB, realB)
                loss_D.backward()
                optim_D.step()

            # Optimize netG with D frozen.
            set_requires_grad([netD], False)
            optim_G.zero_grad()
            _, loss_G = get_loss(tensor, netD, realA, fakeB, realB)
            loss_G.backward()
            optim_G.step()
            if i % 50 == 0:
                print('{}/{}: lossD:{}, lossG:{}'.format(i, e, loss_D, loss_G))

        # Side-by-side panel from the last batch of the epoch.
        # NOTE(review): the third panel repeats realA; presumably realB was
        # intended — confirm before changing.
        visul_img = torch.cat((realA, fakeB, realA), 3)
        visul_img = image_recovery(visul_img)
        save_image(visul_img,
                   os.path.join(save_img_path, 'epoch' + str(e) + '.png'))

        if e > opt.niter:
            # Linear LR decay phase.
            # BUG FIX: the original read `lr = (optim_G, opt.lr,
            # opt.niter_decay)`, which built a tuple instead of calling
            # update_lr — opt.lr became a tuple and G's LR never decayed.
            update_lr(optim_D, opt.lr, opt.niter_decay)
            lr = update_lr(optim_G, opt.lr, opt.niter_decay)
            opt.lr = lr

        if e % opt.save_epoch_freq == 0:
            save_net(netG, opt.checkpoints_dir, 'G', e)
            save_net(netD, opt.checkpoints_dir, 'D', e)
Ejemplo n.º 15
0
    def train(self, single=False):
        """Cycles through alternately training the shared parameters and the
        controller, as described in Section 2.2, Training ENAS and Deriving
        Architectures, of the paper.

        From the paper (for Penn Treebank):

        - In the first phase, shared parameters omega are trained for 400
          steps, each on a minibatch of 64 examples.

        - In the second phase, the controller's parameters are trained for 2000
          steps.

        Args:
            single (bool): If True it won't train the controller and use the
                           same dag instead of derive().
        """
        shared_train_times = []
        controller_train_times = []
        dag = utils.load_dag(self.args) if single else None
        # Initialize best_dag so the train_best branch below cannot raise
        # NameError when no save epoch has fired yet (e.g. after resuming
        # from an epoch not aligned with save_epoch).  NOTE(review): in that
        # case train_shared may receive None — confirm that is acceptable.
        best_dag = dag
        self.shared.forward_evals = 0
        if self.args.shared_initial_step > 0:
            self.train_shared(self.args.shared_initial_step)
            self.train_controller()

        for self.epoch in range(self.start_epoch, self.args.max_epoch):
            # 1. Training the shared parameters omega of the child models
            start_time = time.time()
            self.train_shared()
            shared_train_time = time.time() - start_time
            shared_train_times.append(shared_train_time)
            logger.info(
                f'>>> train_shared() time: {shared_train_time} Epoch: {self.epoch}'
            )

            # 2. Training the controller parameters theta
            if not single:
                start_time = time.time()
                self.train_controller()
                controller_train_time = time.time() - start_time
                controller_train_times.append(controller_train_time)
                # BUG FIX: this message previously logged shared_train_time.
                logger.info(
                    f'>>> train_controller() time: {controller_train_time} Epoch: {self.epoch}'
                )

            if self.epoch % self.args.save_epoch == 0:
                with _get_no_grad_ctx_mgr():
                    best_dag = dag if dag else self.derive()
                    loss, ppl = self.evaluate(self.eval_data,
                                              best_dag,
                                              'val_best',
                                              max_num=self.args.batch_size *
                                              100)
                    # Track the best (lowest-perplexity) dag seen so far.
                    if ppl < self.best_ppl:
                        self.best_ppl = ppl
                        self.best_evaluated_dag = best_dag
                        self.best_epoch = self.epoch
                self.save_model()
            #######################################################################
            #(PT)
            #MISSING: Best (highest reward) child model needs to be re-trained from
            #scratch here and evaluated for perplexity on the validation set
            #######################################################################
            if (self.args.train_best):
                # BUG FIX: the message previously said 1000 steps although
                # the call below trains for 2000.
                logger.info('>> train_shared(2000, best_dag)')
                self.train_shared(2000, best_dag)
                logger.info('<< finished training best_dag')

            if self.epoch >= self.args.shared_decay_after:
                utils.update_lr(self.shared_optim, self.shared_lr)
        self.save_dag(self.best_evaluated_dag)
        logger.info(f'BEFORE RETRAINING BEST DAG:')
        logger.info(f'Best Dag: {self.best_evaluated_dag}')
        logger.info(f'Found in epoch: {self.best_epoch}')
        logger.info(f'With perplexity: {self.best_ppl}')
        logger.info(f'AFTER RETRAINING BEST DAG:')
        logger.info(
            '>> Final evaluation: train_shared(2000, best_evaluated_dag)')
        self.train_shared(2000, self.best_evaluated_dag)
        logger.info('<< finished training best_evaluated_dag')
        self.save_shared()
        # Report timing variance so scheduling behaviour can be inspected.
        shared_train_time_variance = np.var(shared_train_times)
        logger.info(
            f'Shared Training time variance: {shared_train_time_variance}')
        controller_train_time_variance = np.var(controller_train_times)
        logger.info(
            f'Controller Training time variance: {controller_train_time_variance}'
        )
        logger.info(f'shared train times: {shared_train_times}')
        logger.info(f'controller train times: {controller_train_times}')
Ejemplo n.º 16
0
        loss = loss_NMD1 + loss_NMD2*0.1 + loss_NMD3
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (i+1) % 10 == 0:
            acc_alpha = evaluate_NMD(model, alpha, 'alpha')
            acc_sigma = evaluate_NMD(model, sigma, 'sigma')
            
            if alpha < 0.8 and acc_alpha >= 0.95:
                alpha += 0.1
                torch.save(model.state_dict(), f'./models/weights/NMD_alpha_{alpha:.1f}_sigma_{sigma:.4f}.pth')
                
            if sigma > 0.0044 and acc_sigma >= 0.95:
                sigma *= 0.8
                torch.save(model.state_dict(), f'./models/weights/NMD_alpha_{alpha:.1f}_sigma_{sigma:.4f}.pth')

            etime = time.time() - stime
            rtime = etime * (total_epoch_iter-iter_count) / (iter_count+eps)
            print(f'Epoch: {epoch+1:04d}/{num_epochs}, Iter: {i+1:03d}/{total_iter}, Loss: {loss.data:.4f},', end=' ')
            print(f'Acc alpha: {acc_alpha:.2f}, alpha: {alpha:.1f}, Acc sigma: {acc_sigma:.2f}, sigma: {sigma:.4f},', end=' ')
            print(f'Elapsed: {sec2time(etime)}, Remaining: {sec2time(rtime)}')
            
        if alpha >= 0.8 and sigma <= 0.0044:
            torch.save(model.state_dict(), f'./models/weights/NMD.pth')
            break
            
    if (epoch+1) % 10 == 0:
        current_lr *= 0.5
        update_lr(optimizer, current_lr)
Ejemplo n.º 17
0
        summary.add_scalar(f'loss G/loss Overall',
                           loss_overall.data.cpu().numpy(), iter_count)

        etime = time.time() - stime
        rtime = etime * (total_epoch_iter - iter_count) / (iter_count + eps)
        print(
            f'Epoch: {epoch+1:03d}/{num_epochs:03d}, Iter: {i+1:04d}/{total_iter:04d}, ',
            end='')
        print(f'Loss G: {loss_overall.data:.4f}, Loss D: {loss_D.data:.4f}, ',
              end='')
        print(f'Elapsed: {sec2time(etime)}, Remaining: {sec2time(rtime)}')

        if (i + 1) % 10 == 0:
            summary.add_image(f'image/sr_image', sr[0], iter_count)
            summary.add_image(f'image/lr_image', lr[0], iter_count)
            summary.add_image(f'image/hr_image', hr[0], iter_count)

    torch.save(
        G.state_dict(),
        f'./models/weights/G_epoch_{epoch+1}_loss_{loss_overall.data:.4f}.pth')
    torch.save(
        D.state_dict(),
        f'./models/weights/D_epoch_{epoch+1}_loss_{loss_D.data:.4f}.pth')

    if (epoch + 1) % 10 == 0:
        learning_rateG *= 0.5
        learning_rateD *= 0.5

    update_lr(optimizerG, learning_rateG)
    update_lr(optimizerD, learning_rateD)
Ejemplo n.º 18
0
discriminator.cuda()
unet.cuda()
EPOCH = 100
num_iter = len(train_loader)
D_LOSS = []
G_LOSS = []
# S_LOSS=[]
f = open("./loss_gan.txt", 'a')
print(time.strftime('|---------%Y-%m-%d   %H:%M:%S---------|',
                    time.localtime(time.time())),
      file=f)
discriminator.train()
unet.train()
for epoch in range(EPOCH):
    if epoch == 30:
        update_lr(optimizer_g, 0.0001)
        update_lr(optimizer_d, 0.0001)
        update_lr(optimizer_s, 0.0001)
        print('change lr to :', optimizer_g.param_groups[0]['lr'])
    elif epoch == 60:
        update_lr(optimizer_g, 0.00005)
        update_lr(optimizer_d, 0.00005)
        update_lr(optimizer_s, 0.00005)
        print('change lr to :', optimizer_g.param_groups[0]['lr'])
    elif epoch == 90:
        update_lr(optimizer_g, 0.00001)
        update_lr(optimizer_d, 0.00001)
        update_lr(optimizer_s, 0.00001)
        print('change lr to :', optimizer_g.param_groups[0]['lr'])
    d_loss_ = 0
    g_loss_ = 0
Ejemplo n.º 19
0
    def train(self):
        """Run the main training loop for the matting generator ``self.G``.

        For each step in ``[start, total_step]``: fetch a batch (restarting
        the dataloader when exhausted), update the learning rate (linear
        warmup, then the scheduler), forward G, compute the enabled losses
        (rec / comp / lap) over the three output scales (os1 / os4 / os8),
        back-propagate with optional adaptive gradient clipping, and
        periodically write stdout/tensorboard logs, evaluate on the test
        set, and save "latest" / "best" checkpoints.

        Fixes relative to the previous revision:
        - the dataloader-restart handler catches ``StopIteration`` instead
          of a bare ``except:`` that also hid unrelated errors;
        - the checkpoint log message formats ``step`` instead of the
          builtin ``iter`` (which printed ``<built-in function iter>``).
        """
        data_iter = iter(self.train_dataloader)

        # Resume from the step after the checkpoint, otherwise start fresh.
        if self.train_config.resume_checkpoint:
            start = self.resume_step + 1
        else:
            start = 0

        # Exponential moving average of the gradient norm; used as an
        # adaptive clipping threshold when train_config.clip_grad is set.
        moving_max_grad = 0
        moving_grad_moment = 0.999
        max_grad = 0

        for step in range(start, self.train_config.total_step + 1):
            try:
                image_dict = next(data_iter)
            except StopIteration:
                # Dataloader exhausted: start a new pass over the data.
                data_iter = iter(self.train_dataloader)
                image_dict = next(data_iter)

            image, alpha, trimap, mask = image_dict['image'], image_dict[
                'alpha'], image_dict['trimap'], image_dict['mask']
            image = image.cuda()
            alpha = alpha.cuda()
            trimap = trimap.cuda()
            mask = mask.cuda()
            fg_norm, bg_norm = image_dict['fg'].cuda(), image_dict['bg'].cuda()
            # train() of DistributedDataParallel has no return
            self.G.train()
            log_info = ""
            loss = 0
            """===== Update Learning Rate ====="""
            # Linear warmup only for fresh runs; resumed runs go straight to
            # the scheduler.
            if step < self.train_config.warmup_step and self.train_config.resume_checkpoint is None:
                cur_G_lr = utils.warmup_lr(self.train_config.G_lr, step + 1,
                                           self.train_config.warmup_step)
                utils.update_lr(cur_G_lr, self.G_optimizer)

            else:
                self.G_scheduler.step()
                # NOTE(review): get_lr() is deprecated in newer PyTorch in
                # favour of get_last_lr(); kept for compatibility here.
                cur_G_lr = self.G_scheduler.get_lr()[0]
            """===== Forward G ====="""

            pred = self.G(image, mask)
            alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred[
                'alpha_os1'], pred['alpha_os4'], pred['alpha_os8']

            # os8 (coarsest) is supervised everywhere.
            weight_os8 = utils.get_unknown_tensor(trimap)
            weight_os8[...] = 1

            # Progressive self-refinement: early on, supervise os4/os1 on the
            # trimap's unknown region; later, only where the coarser scale's
            # prediction is uncertain, copying the coarser prediction
            # elsewhere.
            flag = False
            if step < self.train_config.warmup_step:
                flag = True
                weight_os4 = utils.get_unknown_tensor(trimap)
                weight_os1 = utils.get_unknown_tensor(trimap)
            elif step < self.train_config.warmup_step * 3:
                # Transition phase: randomly mix the two supervision modes.
                if random.randint(0, 1) == 0:
                    flag = True
                    weight_os4 = utils.get_unknown_tensor(trimap)
                    weight_os1 = utils.get_unknown_tensor(trimap)
                else:
                    weight_os4 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os8,
                        rand_width=CONFIG.model.self_refine_width1,
                        train_mode=True)
                    alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4
                                                                     == 0]
                    weight_os1 = utils.get_unknown_tensor_from_pred(
                        alpha_pred_os4,
                        rand_width=CONFIG.model.self_refine_width2,
                        train_mode=True)
                    alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1
                                                                     == 0]
            else:
                weight_os4 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os8,
                    rand_width=CONFIG.model.self_refine_width1,
                    train_mode=True)
                alpha_pred_os4[weight_os4 == 0] = alpha_pred_os8[weight_os4 ==
                                                                 0]
                weight_os1 = utils.get_unknown_tensor_from_pred(
                    alpha_pred_os4,
                    rand_width=CONFIG.model.self_refine_width2,
                    train_mode=True)
                alpha_pred_os1[weight_os1 == 0] = alpha_pred_os4[weight_os1 ==
                                                                 0]
            """===== Calculate Loss ====="""
            # Each loss is a 2:1:1 weighted mean over the os1/os4/os8 scales.
            if self.train_config.rec_weight > 0:
                self.loss_dict['rec'] = (self.regression_loss(alpha_pred_os1, alpha, loss_type='l1', weight=weight_os1) * 2 +\
                 self.regression_loss(alpha_pred_os4, alpha, loss_type='l1', weight=weight_os4) * 1 +\
                  self.regression_loss(alpha_pred_os8, alpha, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.rec_weight

            if self.train_config.comp_weight > 0:
                self.loss_dict['comp'] = (self.composition_loss(alpha_pred_os1, fg_norm, bg_norm, image, weight=weight_os1) * 2 +\
                 self.composition_loss(alpha_pred_os4, fg_norm, bg_norm, image, weight=weight_os4) * 1 +\
                  self.composition_loss(alpha_pred_os8, fg_norm, bg_norm, image, weight=weight_os8) * 1) / 5.0 * self.train_config.comp_weight

            if self.train_config.lap_weight > 0:
                self.loss_dict['lap'] = (self.lap_loss(logit=alpha_pred_os1, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os1) * 2 +\
                 self.lap_loss(logit=alpha_pred_os4, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os4) * 1 +\
                  self.lap_loss(logit=alpha_pred_os8, target=alpha, gauss_filter=self.gauss_filter, loss_type='l1', weight=weight_os8) * 1) / 5.0 * self.train_config.lap_weight

            for loss_key in self.loss_dict.keys():
                if self.loss_dict[loss_key] is not None and loss_key in [
                        'rec', 'comp', 'lap'
                ]:
                    loss += self.loss_dict[loss_key]
            """===== Back Propagate ====="""
            self.reset_grad()

            loss.backward()
            """===== Clip Large Gradient ====="""
            # Adaptive clipping: allow up to 2x the moving average of past
            # gradient norms, then fold the observed norm back into the
            # average.
            if self.train_config.clip_grad:
                if moving_max_grad == 0:
                    moving_max_grad = nn_utils.clip_grad_norm_(
                        self.G.parameters(), 1e+6)
                    max_grad = moving_max_grad
                else:
                    max_grad = nn_utils.clip_grad_norm_(
                        self.G.parameters(), 2 * moving_max_grad)
                    moving_max_grad = moving_max_grad * moving_grad_moment + max_grad * (
                        1 - moving_grad_moment)
            """===== Update Parameters ====="""
            self.G_optimizer.step()
            """===== Write Log and Tensorboard ====="""
            # stdout log
            if step % self.log_config.logging_step == 0:
                # reduce losses from GPUs
                if CONFIG.dist:
                    self.loss_dict = utils.reduce_tensor_dict(self.loss_dict,
                                                              mode='mean')
                    loss = utils.reduce_tensor(loss)
                # create logging information
                for loss_key in self.loss_dict.keys():
                    if self.loss_dict[loss_key] is not None:
                        log_info += loss_key.upper() + ": {:.4f}, ".format(
                            self.loss_dict[loss_key])

                self.logger.debug(
                    "Image tensor shape: {}. Trimap tensor shape: {}".format(
                        image.shape, trimap.shape))
                log_info = "[{}/{}], ".format(
                    step, self.train_config.total_step) + log_info
                log_info += "lr: {:6f}".format(cur_G_lr)
                self.logger.info(log_info)

                # tensorboard
                if step % self.log_config.tensorboard_step == 0 or step == start:  # and step > start:
                    self.tb_logger.scalar_summary('Loss', loss, step)

                    # detailed losses
                    for loss_key in self.loss_dict.keys():
                        if self.loss_dict[loss_key] is not None:
                            self.tb_logger.scalar_summary(
                                'Loss_' + loss_key.upper(),
                                self.loss_dict[loss_key], step)

                    self.tb_logger.scalar_summary('LearnRate', cur_G_lr, step)

                    if self.train_config.clip_grad:
                        self.tb_logger.scalar_summary('Moving_Max_Grad',
                                                      moving_max_grad, step)
                        self.tb_logger.scalar_summary('Max_Grad', max_grad,
                                                      step)
            """===== TEST ====="""
            if ((step % self.train_config.val_step) == 0 or step
                    == self.train_config.total_step):  # and step > start:
                self.G.eval()
                test_loss = 0
                log_info = ""

                self.test_loss_dict['mse'] = 0
                self.test_loss_dict['sad'] = 0
                for loss_key in self.loss_dict.keys():
                    if loss_key in self.test_loss_dict and self.loss_dict[
                            loss_key] is not None:
                        self.test_loss_dict[loss_key] = 0

                with torch.no_grad():
                    for image_dict in self.test_dataloader:
                        image, alpha, trimap, mask = image_dict[
                            'image'], image_dict['alpha'], image_dict[
                                'trimap'], image_dict['mask']
                        alpha_shape = image_dict['alpha_shape']
                        image = image.cuda()
                        alpha = alpha.cuda()
                        trimap = trimap.cuda()
                        mask = mask.cuda()

                        pred = self.G(image, mask)

                        alpha_pred_os1, alpha_pred_os4, alpha_pred_os8 = pred[
                            'alpha_os1'], pred['alpha_os4'], pred['alpha_os8']
                        # Coarse-to-fine fusion: start from os8 and overwrite
                        # the uncertain regions with os4, then os1.
                        alpha_pred = alpha_pred_os8.clone().detach()
                        weight_os4 = utils.get_unknown_tensor_from_pred(
                            alpha_pred,
                            rand_width=CONFIG.model.self_refine_width1,
                            train_mode=False)
                        alpha_pred[weight_os4 > 0] = alpha_pred_os4[
                            weight_os4 > 0]
                        weight_os1 = utils.get_unknown_tensor_from_pred(
                            alpha_pred,
                            rand_width=CONFIG.model.self_refine_width2,
                            train_mode=False)
                        alpha_pred[weight_os1 > 0] = alpha_pred_os1[
                            weight_os1 > 0]

                        # Crop padding back to the original alpha size.
                        h, w = alpha_shape
                        alpha_pred = alpha_pred[..., :h, :w]
                        trimap = trimap[..., :h, :w]

                        weight = utils.get_unknown_tensor(trimap)
                        weight[...] = 1

                        # value of MSE/SAD here is different from test.py and matlab version
                        self.test_loss_dict['mse'] += self.mse(
                            alpha_pred, alpha, weight)
                        self.test_loss_dict['sad'] += self.sad(
                            alpha_pred, alpha, weight)

                        if self.train_config.rec_weight > 0:
                            self.test_loss_dict['rec'] += self.regression_loss(alpha_pred, alpha, weight=weight) \
                                                          * self.train_config.rec_weight

                # reduce losses from GPUs
                if CONFIG.dist:
                    self.test_loss_dict = utils.reduce_tensor_dict(
                        self.test_loss_dict, mode='mean')
                """===== Write Log and Tensorboard ====="""
                # stdout log
                for loss_key in self.test_loss_dict.keys():
                    if self.test_loss_dict[loss_key] is not None:
                        self.test_loss_dict[loss_key] /= len(
                            self.test_dataloader)
                        # logging
                        log_info += loss_key.upper() + ": {:.4f} ".format(
                            self.test_loss_dict[loss_key])
                        self.tb_logger.scalar_summary(
                            'Loss_' + loss_key.upper(),
                            self.test_loss_dict[loss_key],
                            step,
                            phase='test')

                        if loss_key in ['rec']:
                            test_loss += self.test_loss_dict[loss_key]

                self.logger.info("TEST: LOSS: {:.4f} ".format(test_loss) +
                                 log_info)
                self.tb_logger.scalar_summary('Loss',
                                              test_loss,
                                              step,
                                              phase='test')

                # if self.model_config.trimap_channel == 3:
                #     trimap = trimap.argmax(dim=1, keepdim=True)
                # alpha_pred[trimap==2] = 1
                # alpha_pred[trimap==0] = 0
                # Visualize the last test sample of the last batch.
                image_set = {
                    'image':
                    (utils.normalize_image(image[-1, ...]).data.cpu().numpy() *
                     255).astype(np.uint8),
                    'mask':
                    (mask[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                    'alpha':
                    (alpha[-1, ...].data.cpu().numpy() * 255).astype(np.uint8),
                    'alpha_pred': (alpha_pred[-1, ...].data.cpu().numpy() *
                                   255).astype(np.uint8)
                }

                self.tb_logger.image_summary(image_set, step, phase='test')
                """===== Save Model ====="""
                # Only rank 0 writes checkpoints; "best" tracks lowest MSE.
                if (step % self.log_config.checkpoint_step == 0 or step == self.train_config.total_step) \
                        and CONFIG.local_rank == 0 and (step > start):
                    self.logger.info(
                        'Saving the trained models from step {}...'.format(
                            step))
                    self.save_model("latest_model", step, loss)
                    if self.test_loss_dict['mse'] < self.best_loss:
                        self.best_loss = self.test_loss_dict['mse']
                        self.save_model("best_model", step, loss)

                torch.cuda.empty_cache()
Ejemplo n.º 20
0
            for i, data_dict in enumerate(data_loader):
                package = {'Epoch': epoch + 1}
                # NOTE(review): this local shadows the `time` module name
                # within the loop body; rename if the module is needed here.
                time = datetime.datetime.now()

                current_step += 1
                package.update({'Current_step': current_step})
                # Move every tensor in the batch to the target device.
                for k, v in data_dict.items():
                    data_dict.update({k: v.to(device)})

                # criterion returns the loss terms for both networks; A is
                # updated every step.
                package.update(criterion(A, G, data_dict, current_step))
                A_optim.zero_grad()
                package['total_A_loss'].backward()
                A_optim.step()

                # G is updated only every `n_critics` steps (critic-style
                # alternating schedule).
                if current_step % n_critics == 0:
                    G_optim.zero_grad()
                    package['total_G_loss'].backward()
                    G_optim.step()
                package.update({'running_time': str(datetime.datetime.now() - time)})
                manager(package)
                del package

                # In debug mode, process a single batch per epoch.
                if opt.debug:
                    break

            manager.layer_magnitude(G, epoch + 1)
            # Begin linear LR decay after `epoch_decay` epochs.
            if epoch > epoch_decay:
                lr = update_lr(lr, opt.n_epochs - opt.epoch_decay, A_optim, G_optim)

    print("Total time taken: ", datetime.datetime.now() - start_time)